# Owner(s): ["module: type promotion"]

from functools import wraps
import itertools
import unittest

import torch

from torch.testing._internal.common_utils import (TestCase, run_tests, load_tests, make_tensor,
                                                  TEST_NUMPY, set_default_dtype, torch_to_numpy_dtype_dict,
                                                  numpy_to_torch_dtype_dict, skipIfTorchDynamo,
                                                  xfailIfTorchDynamo)
from torch.testing._internal.common_device_type import (instantiate_device_type_tests, onlyNativeDeviceTypes,
                                                        dtypes, onlyCPU, expectedFailureMeta, skipMeta)
from torch.testing._internal.common_dtype import (
    all_types_and_complex_and, get_all_math_dtypes, floating_types, get_all_dtypes,
    float_to_corresponding_complex_type_map,
)


import numpy as np
import operator

# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests

# Not thread-safe decorator that runs the decorated test once with
# the default dtype being torch.float and again with the default dtype
# being torch.double.
def float_double_default_dtype(fn):
    @wraps(fn)
    def wrapped_fn(*args, **kwargs):
        with set_default_dtype(torch.float):
            fn(*args, **kwargs)
        with set_default_dtype(torch.double):
            fn(*args, **kwargs)

    return wrapped_fn

class TestTypePromotion(TestCase):

    # In-place operations don't promote.
    # `int+float -> float` but `int.add_(float)` is rejected as an error.
    # Promoting inplace would require re-allocating and copying the memory of the
    # tensor data, since element size could change.
    # https://github.com/pytorch/pytorch/issues/127049
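    # A quick sketch of the rule above (assumed behavior, illustration only):
    #   >>> t = torch.ones(2, dtype=torch.int32)
    #   >>> (t + 1.5).dtype        # out-of-place: promotes to the default float dtype
    #   torch.float32
    #   >>> t.add_(1.5)            # in-place: RuntimeError, the float result can't be cast back to int32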
    @xfailIfTorchDynamo
    @float_double_default_dtype
    def test_inplace(self, device):
        int_tensor = torch.ones([4, 4, 4], dtype=torch.int32, device=device)

        self.assertRaisesRegex(RuntimeError, "can't be cast to", lambda: int_tensor.add_(1.5))

        expected = torch.ones([4, 4, 4], dtype=torch.int32, device=device)

        long_tensor = torch.ones([4, 4, 4], dtype=torch.int64, device=device)
        int_tensor.add_(long_tensor)
        int_tensor.add_(1)
        three = expected + 2
        self.assertEqual(int_tensor, three)
        self.assertEqual(int_tensor.dtype, torch.int32)

        bool_tensor = torch.tensor([1, 1, 1], dtype=torch.bool, device=device)
        uint8_tensor = torch.tensor([1, 1, 1], dtype=torch.uint8, device=device)
        # We treat bool as a separate category, which means uint8 cannot cast to bool.
        self.assertRaisesRegex(RuntimeError, "can't be cast to", lambda: bool_tensor.add_(uint8_tensor))
        # We allow demotion from signed to unsigned, unlike numpy, because:
        # * We don't want the performance penalty of inspecting scalar values.
        # * We don't want 'signed' to be considered a distinct 'category'
        #   in promotion rules; if it were, uint16_tensor + 5 would result in
        #   a long tensor, which is not what we want.
        int16_tensor = torch.tensor([1, 1, 1], dtype=torch.int16, device=device)
        uint8_tensor *= int16_tensor
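        # Descriptive note (not an assertion from the original test): because in-place
        # ops never change the dtype of `self`, uint8_tensor is expected to remain uint8
        # here even though int16 is the "wider" operand.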

    @float_double_default_dtype
    def test_unsigned(self, device):
        dont_promote = torch.ones(3, dtype=torch.uint8, device=device) + 5
        self.assertEqual(dont_promote.dtype, torch.uint8)

    # some basic examples

    @float_double_default_dtype
    def test_int_promotion(self, device):
        a = torch.ones([4, 4, 4], dtype=torch.int32, device=device)
        b = torch.ones([4, 4, 4], dtype=torch.int64, device=device)
        c = a + b
        self.assertEqual(c, b + b)
        self.assertEqual(c.dtype, torch.int64)

    @float_double_default_dtype
    def test_float_promotion(self, device):
        def test_promotion(dtype_float, dtype_double):
            a = torch.ones([4, 4, 4], dtype=dtype_float, device=device)
            b = torch.ones([4, 4, 4], dtype=dtype_double, device=device)
            c = a + b
            self.assertEqual(c, b + b)
            self.assertEqual(c.dtype, dtype_double)
            c = b + a
            self.assertEqual(c, b + b)
            self.assertEqual(c.dtype, dtype_double)
        test_promotion(torch.float, torch.double)

    @float_double_default_dtype
    def test_complex_promotion(self, device):
        def test_promotion(dtype_float, dtype_double):
            a = torch.ones([4, 4, 4], dtype=dtype_float, device=device)
            b = torch.ones([4, 4, 4], dtype=dtype_double, device=device)
            c = a + b
            self.assertEqual(c, b + b)
            self.assertEqual(c.dtype, dtype_double)
            c = b + a
            self.assertEqual(c, b + b)
            self.assertEqual(c.dtype, dtype_double)

        test_promotion(torch.complex64, torch.complex128)

        a = torch.randn(3, dtype=torch.complex64, device=device)
        self.assertEqual((a * 5).dtype, torch.complex64)
        # not a "wrapped number"
        other = torch.tensor(5.5, dtype=torch.double, device=device)
        self.assertEqual((a + other).dtype, torch.complex64)

        def make_scalar_tensor(dtype):
            return make_tensor((), dtype=dtype, device=device)

        def make_1d_tensor(dtype):
            return make_tensor((3,), dtype=dtype, device=device)

        def complex_scalar_tensor_test(s, t):
            # As per type promotion rules,
            # Complex Scalar and Float Tensor -> Complex Tensor with Value type of Float Tensor
            # Complex Scalar and Integral Tensor -> Complex Tensor with Value type of Complex Scalar
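            # Hedged examples of those two rules (for illustration only):
            #   cfloat scalar * double tensor -> cdouble tensor  (value type of the float tensor)
            #   cfloat scalar * long   tensor -> cfloat  tensor  (value type of the complex scalar)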

            if t.dtype.is_floating_point:
                # defaults to return complex64 (for bfloat16)
                expected_dtype = float_to_corresponding_complex_type_map.get(t.dtype, torch.complex64)
            else:  # integral tensor
                if isinstance(s, torch.Tensor):
                    expected_dtype = s.dtype
                else:
                    expected_dtype = float_to_corresponding_complex_type_map[torch.get_default_dtype()]
            self.assertEqual((s * t).dtype, expected_dtype)
            self.assertEqual((t * s).dtype, expected_dtype)
            self.assertEqual(torch.result_type(s, t), expected_dtype)
            self.assertEqual(torch.result_type(t, s), expected_dtype)

        if torch.device(device).type != 'xla':
            # chalf is not supported on XLA
            s = make_scalar_tensor(dtype=torch.chalf)
            # Same Value type
            t = make_1d_tensor(dtype=torch.half)
            # 0-D Tensor X 1-D Tensor
            complex_scalar_tensor_test(s, t)
            # Python Scalar X 1-D Tensor
            complex_scalar_tensor_test(s.item(), t)

            # Higher Value Type
            t = make_1d_tensor(dtype=torch.float)
            complex_scalar_tensor_test(s, t)
            complex_scalar_tensor_test(s.item(), t)

            # Special Case
            t = make_1d_tensor(dtype=torch.bfloat16)
            complex_scalar_tensor_test(s, t)
            complex_scalar_tensor_test(s.item(), t)

            # Integral Tensor
            t = make_1d_tensor(dtype=torch.long)
            complex_scalar_tensor_test(s, t)
            complex_scalar_tensor_test(s.item(), t)

        # CFloat Scalar
        s = make_scalar_tensor(dtype=torch.cfloat)
        # Lower Value type than CFloat
        t = make_1d_tensor(dtype=torch.half)
        complex_scalar_tensor_test(s, t)
        complex_scalar_tensor_test(s.item(), t)

        # Higher Value type than CFloat
        t = make_1d_tensor(dtype=torch.double)
        complex_scalar_tensor_test(s, t)
        complex_scalar_tensor_test(s.item(), t)

        # Integral Tensor
        t = make_1d_tensor(dtype=torch.long)
        # 0-D Tensor X 1-D Tensor
        complex_scalar_tensor_test(s, t)
        # Python Scalar X 1-D Tensor
        complex_scalar_tensor_test(s.item(), t)

        # CDouble Scalar
        s = make_scalar_tensor(dtype=torch.cdouble)

        # Lower Value type than CDouble
        t = make_1d_tensor(dtype=torch.float)
        complex_scalar_tensor_test(s, t)
        complex_scalar_tensor_test(s.item(), t)

        # Special Case
        t = make_1d_tensor(dtype=torch.bfloat16)
        complex_scalar_tensor_test(s, t)
        complex_scalar_tensor_test(s.item(), t)

    @float_double_default_dtype
    def test_complex_scalar_mult_tensor_promotion(self, device):
        a = 1j * torch.ones(2, device=device)
        a = a + 1j
        b = torch.tensor([2j, 2j], device=device)
        self.assertEqual(a, b)
        self.assertEqual(a.dtype, b.dtype)

    @float_double_default_dtype
    def test_add_wrapped(self, device):
        a = torch.ones([4, 4, 4], dtype=torch.int, device=device)
        b = 1
        c = a + b
        self.assertEqual(c, a + a)
        self.assertEqual(c.dtype, torch.int)

    @float_double_default_dtype
    def test_int_to_float(self, device):
        a = torch.ones([4, 4, 4], dtype=torch.int32, device=device)
        b = torch.ones([4, 4, 4], dtype=torch.float, device=device)
        c = a + b
        self.assertEqual(c.dtype, torch.float32)

    # some examples from:
    # https://github.com/pytorch/pytorch/issues/9515

    @float_double_default_dtype
    def test_from_issue(self, device):
        a = torch.rand(3, dtype=torch.float32, device=device)
        u = torch.tensor([0, 0, 1], dtype=torch.uint8, device=device)
        self.assertEqual((a * 5).dtype, torch.float32)
        self.assertEqual((u + 1).dtype, torch.uint8)
        self.assertEqual((u + 1000).dtype, torch.uint8)  # integer overflow

        # not a "wrapped number"
        other = torch.tensor(5.5, dtype=torch.double, device=device)

        self.assertEqual((u + 5.5).dtype, torch.get_default_dtype())
        self.assertEqual((u + other).dtype, torch.double)
        # adding a 0-dim tensor to a float doesn't promote to double unless first
        # type was integral.
        self.assertEqual((a + other).dtype, torch.float32)
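        # Hedged summary: a bare Python float such as 5.5 is a "wrapped number" and only
        # contributes the default scalar dtype, whereas a 0-dim tensor like `other` keeps
        # its own dtype, which wins only when its dtype category (floating point here) is
        # higher than that of the other operand.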

    @float_double_default_dtype
    def test_half(self, device):
        half = torch.tensor(5.5, dtype=torch.float16, device=device)
        self.assertEqual((half + 2.2).dtype, torch.float16)
        self.assertEqual((half + 100000).dtype, torch.float16)  # inf
        default_tensor = torch.tensor(100000.0, device=device)
        self.assertEqual((half + default_tensor).dtype, torch.get_default_dtype())

    def test_bfloat16(self, device):
        # with scalar
        bf = torch.tensor(5.5, dtype=torch.bfloat16, device=device)
        for scalar in (2.2, 5, 100000):   # bf + 100000 is inf
            self.assertEqual((bf + scalar).dtype, torch.bfloat16)
            self.assertEqual(scalar + bf, bf + scalar)

        for scalar in (complex(1, 1), complex(-2, 0), complex(0, -3)):
            self.assertEqual((bf + scalar).dtype, torch.cfloat)
            self.assertEqual(bf + scalar, scalar + bf)

        # with tensor
        for dtype in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool):
            t = torch.tensor(1, dtype=dtype, device=device)
            self.assertEqual(bf + t, t + bf)
            if dtype in (torch.float16, torch.float32, torch.float64, torch.cfloat, torch.cdouble):
                # Handles bfloat16 x float16 -> float32 promotion
                expected_dtype = dtype if dtype != torch.half else torch.float32
            elif dtype is torch.chalf:
                expected_dtype = torch.cfloat
            elif dtype in (torch.bool, torch.uint8,
                           torch.int8, torch.int16, torch.int32, torch.int64, torch.bfloat16):
                expected_dtype = torch.bfloat16
            else:
                raise AssertionError(f'Missing dtype {dtype} not tested.')

            self.assertEqual(torch.promote_types(dtype, torch.bfloat16), expected_dtype)
            self.assertEqual(torch.promote_types(torch.bfloat16, dtype), expected_dtype)
            self.assertEqual((bf + t).dtype, expected_dtype)

    @onlyNativeDeviceTypes
    def test_complex_half(self, device):
        # with scalar
        chalf = torch.tensor(5.5, dtype=torch.chalf, device=device)
        for scalar in (2.2, 5, 100000):   # chalf + 100000 is inf
            self.assertEqual((chalf * scalar).dtype, torch.chalf)
            self.assertEqual(scalar * chalf, chalf * scalar)

        for scalar in (complex(1, 1), complex(-2, 0), complex(0, -3)):
            self.assertEqual((chalf * scalar).dtype, torch.chalf)
            self.assertEqual(chalf * scalar, scalar * chalf)

        # with tensor
        dtypes = all_types_and_complex_and(torch.chalf, torch.half, torch.bfloat16, torch.bool)
        for dtype in dtypes:
            t = torch.tensor(1, dtype=dtype, device=device)
            self.assertEqual(chalf * t, t * chalf)
            if dtype in (torch.float16, torch.chalf):
                expected_dtype = torch.chalf
            elif dtype in (torch.float, torch.double, torch.bfloat16):
                expected_dtype = torch.cdouble if dtype is torch.double else torch.cfloat
            elif dtype in (torch.cfloat, torch.cdouble):
                expected_dtype = dtype
            elif dtype in (torch.bool, torch.uint8,
                           torch.int8, torch.int16, torch.int32, torch.int64):
                expected_dtype = torch.chalf
            else:
                raise AssertionError(f'Missing dtype {dtype} not tested.')

            self.assertEqual(torch.promote_types(dtype, torch.chalf), expected_dtype)
            self.assertEqual(torch.promote_types(torch.chalf, dtype), expected_dtype)
            self.assertEqual((chalf * t).dtype, expected_dtype)

    @float_double_default_dtype
    def test_alternate_result(self, device):
        x = torch.tensor([1, 1, 1, 1], dtype=torch.float, device=device)
        o = torch.tensor([0, 0, 0, 0], dtype=torch.long, device=device)
        self.assertRaisesRegex(RuntimeError,
                               "can't be cast to",
                               lambda: torch.add(x, x, out=o))
        d = torch.tensor([1, 1, 1, 1], dtype=torch.double, device=device)
        torch.add(x, x, out=d)
        self.assertEqual(d.dtype, torch.double)
        x = x.to(torch.double)
        self.assertEqual(x + x, d)

    @float_double_default_dtype
    def test_mixed_type_backward(self, device):
        f = torch.ones([3, 3], dtype=torch.float, requires_grad=True, device=device)
        ten = torch.tensor([10.], dtype=torch.double, device=device)
        tens = f * ten
        s = (tens + 2).sum()
        s.backward()
        expected = f.grad.to(torch.double)
        self.assertEqual(tens, expected)

        # If we don't convert the returned grad_input to the actual input type
        # we get an error like:
        # RuntimeError: Function SubBackward0 returned an invalid gradient at index 0 - expected type \
        # torch.FloatTensor but got torch.DoubleTensor
        f_dtypes = [torch.float, torch.double]
        if self.device_type == 'cuda':
            f_dtypes = f_dtypes + [torch.half]
        i_dtypes = [torch.int, torch.long]
        for func in [torch.add, torch.sub, torch.rsub, torch.mul, torch.div]:
            for dtype1, dtype2 in itertools.product(f_dtypes, f_dtypes + i_dtypes):
                x = torch.ones(10, requires_grad=True, dtype=dtype1, device=device)
                y = torch.ones(10, dtype=dtype2, device=device)
                func(x, y).sum().backward()

    def _get_test_tensor(self, device, dtype, remove_zeros=False):
        shape = [5, 5, 5]
        if dtype == torch.bool:
            tensor = torch.randint(int(remove_zeros), 2, shape, device=device, dtype=dtype)
        elif dtype.is_floating_point or dtype.is_complex:
            # "_th_normal_ not supported on CPUType for Half", so it is simpler to create and then convert
            tensor = torch.randn(shape, device=device)
            tensor = tensor.to(dtype)
            if remove_zeros:
                tensor[torch.abs(tensor) < 0.05] = 5
        else:
            tensor = torch.randint(-5 if dtype.is_signed else 0, 10, shape, device=device, dtype=dtype)
            if remove_zeros:
                tensor[tensor == 0] = 5
        return tensor

    # verifies that torch.<op>(first, second) is the same as
    # torch.<op>(first.to(common_dtype), second.to(common_dtype)) in cases where that should hold.
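    # A minimal sketch of that check (assumed dtypes, illustration only):
    #   common = torch.promote_types(first.dtype, second.dtype)
    #   torch.add(first, second) == torch.add(first.to(common), second.to(common))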
    @float_double_default_dtype
    def test_many_promotions(self, device):
        # Can also include half on CPU in cases where it will be promoted to a
        # supported dtype
        dtypes1 = get_all_math_dtypes('cuda')
        dtypes2 = get_all_math_dtypes(device)
        ops = [torch.add, torch.sub, torch.mul, torch.div, torch.rsub]
        for dt1, dt2 in itertools.product(dtypes1, dtypes2):
            for op, non_contiguous in itertools.product(ops, [True, False]):
                common_dtype = torch.promote_types(dt1, dt2)
                if common_dtype == torch.half and self.device_type == 'cpu':
                    continue
                if op == torch.sub and common_dtype == torch.bool:
                    # Subtraction, the `-` operator, with a bool tensor is not supported.
                    continue
                first = self._get_test_tensor(device, dt1)
                second = self._get_test_tensor(device, dt2, op == torch.div)
                # test ops with non-contiguous tensors
                if non_contiguous:
                    first = first.transpose(0, 2)
                    second = second.transpose(2, 1)
                    self.assertNotEqual(first.stride(), second.stride(),
                                        msg="some non-contiguous issues could be missed if tensors have same strides")

                self.assertEqual(not first.is_contiguous(), non_contiguous)
                self.assertEqual(not second.is_contiguous(), non_contiguous)
                result = op(first, second)
                expected = op(first.to(common_dtype), second.to(common_dtype))
                self.assertEqual(result.dtype, expected.dtype, msg=f'{op.__name__} with {dt1}, {dt2}')
                self.assertEqual(result, expected, msg=f'{op.__name__} with {dt1}, {dt2}')

    @float_double_default_dtype
    def test_non_promoting_ops(self, device):
        x = torch.ones(4, dtype=torch.double, device=device)
        with self.assertRaises(RuntimeError):
            torch.lerp(x, torch.ones(4, dtype=torch.float, device=device), 1)

    @float_double_default_dtype
    def test_alpha_mismatch(self, device):
        x = torch.ones(4, dtype=torch.int, device=device)
        err = 'alpha must not be'
        self.assertRaisesRegex(RuntimeError, err,
                               lambda: torch.add(x, x, alpha=1.1))
        x = x.to(torch.bool)
        self.assertRaisesRegex(RuntimeError, err,
                               lambda: torch.add(x, x, alpha=1.1))
        self.assertEqual(x + x, torch.add(x, x, alpha=True))

    @float_double_default_dtype
    def test_booleans(self, device):
        onedim = torch.tensor([True], device=device)

        self.assertEqual(onedim + onedim, onedim)
        self.assertEqual(onedim + True, onedim)
        self.assertEqual(torch.add(True, True), True)
        self.assertEqual(torch.add(False, False), False)
        self.assertEqual(torch.add(False, True), True)

        self.assertRaisesRegex(RuntimeError, "Boolean alpha only supported",
                               lambda: torch.add(1, 1, alpha=True))
        self.assertEqual(torch.add(torch.tensor(True, device=device),
                         torch.tensor(True, device=device), True),
                         torch.tensor(True, device=device))

    @skipIfTorchDynamo("Not a TorchDynamo suitable test")
    @float_double_default_dtype
    def test_create_bool_tensors(self, device):
        expected = torch.tensor([0], dtype=torch.int64, device=device)
        self.assertEqual(torch.arange(False, True, device=device), expected)
        self.assertEqual(torch.arange(True, device=device), expected)
        expected = torch.tensor([0, 0.5], dtype=torch.get_default_dtype(), device=device)
        self.assertEqual(torch.arange(False, True, 0.5, device=device), expected)
        expected = torch.ones(0, dtype=torch.int64, device=device)
        self.assertEqual(torch.arange(False, False, device=device), expected)

        bool_tensor_lin = torch.linspace(False, True, steps=100, device=device)
        int_tensor_lin = torch.linspace(0, 1, steps=100, device=device)
        self.assertEqual(bool_tensor_lin, int_tensor_lin)
        bool_tensor_log = torch.logspace(False, True, steps=100, device=device)
        int_tensor_log = torch.logspace(0, 1, steps=100, device=device)
        self.assertEqual(bool_tensor_log, int_tensor_log)

        # This seems like odd behavior, but ints also create float tensors; numpy doesn't have this function.
        self.assertEqual(torch.scalar_tensor(False, device=device), torch.tensor(0., device=device))

    @dtypes(*itertools.product(all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool),
                               all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)))
    def test_result_type(self, device, dtypes):
        "Test result_type for tensor vs tensor and scalar vs scalar."

        def _get_dtype(x):
            "Get the dtype of x if x is a tensor. If x is a scalar, get its corresponding dtype if it were a tensor."
            if torch.is_tensor(x):
                return x.dtype
            elif isinstance(x, bool):
                return torch.bool
            elif isinstance(x, int):
                return torch.int64
            elif isinstance(x, float):
                return torch.float32
            elif isinstance(x, complex):
                return torch.complex64
            else:
                raise AssertionError(f"Unknown type {x}")

        # tensor against tensor
        a_tensor = torch.tensor((0, 1), device=device, dtype=dtypes[0])
        a_single_tensor = torch.tensor(1, device=device, dtype=dtypes[0])
        a_scalar = a_single_tensor.item()
        b_tensor = torch.tensor((1, 0), device=device, dtype=dtypes[1])
        b_single_tensor = torch.tensor(1, device=device, dtype=dtypes[1])
        b_scalar = b_single_tensor.item()
        combo = ((a_tensor, a_single_tensor, a_scalar), (b_tensor, b_single_tensor, b_scalar))
        for a, b in itertools.product(*combo):
            dtype_a = _get_dtype(a)
            dtype_b = _get_dtype(b)
            try:
                result = a + b
            except RuntimeError:
                with self.assertRaises(RuntimeError):
                    torch.promote_types(dtype_a, dtype_b)
                with self.assertRaises(RuntimeError):
                    torch.result_type(a, b)
            else:
                dtype_res = _get_dtype(result)
                if a is a_scalar and b is b_scalar and dtype_a == torch.bool and dtype_b == torch.bool:
                    # special case: in Python, True + True is an integer
                    self.assertEqual(dtype_res, torch.int64, f"a == {a}, b == {b}")
                else:
                    self.assertEqual(dtype_res, torch.result_type(a, b), f"a == {a}, b == {b}")
                if a is a_scalar and b is b_scalar:  # Python internal type determination is good enough in this case
                    continue
                if any(a is a0 and b is b0 for a0, b0 in zip(*combo)):  # a and b belong to the same class
                    self.assertEqual(dtype_res, torch.promote_types(dtype_a, dtype_b), f"a == {a}, b == {b}")

    # Spot check some result type for tensor against scalar (including single-element tensor).
    @float_double_default_dtype
    def test_result_type_tensor_vs_scalar(self, device):
        def _test_spot(a, b, res_dtype):
            self.assertEqual(torch.result_type(a, b), res_dtype)
            self.assertEqual(torch.result_type(b, a), res_dtype)

        _test_spot(torch.tensor([1, 2], dtype=torch.half, device=device),
                   torch.tensor(1, dtype=torch.long, device=device), torch.half)
        _test_spot(torch.tensor(1, dtype=torch.float, device=device),
                   torch.tensor([1, 2], dtype=torch.double, device=device), torch.double)
        _test_spot(torch.tensor(1, dtype=torch.int, device=device), 1, torch.int)
        _test_spot(torch.tensor(1, device=device), 1., torch.get_default_dtype())
        _test_spot(torch.tensor(1, dtype=torch.long, device=device),
                   torch.tensor([1, 1], dtype=torch.int, device=device), torch.int)
        _test_spot(torch.tensor([1., 1.], dtype=torch.float, device=device), 1., torch.float)
        _test_spot(torch.tensor([1., 1.], dtype=torch.complex64, device=device),
                   torch.tensor(1., dtype=torch.complex128, device=device), torch.complex64)
        _test_spot(torch.tensor([1., 1.], dtype=torch.complex128, device=device),
                   torch.tensor(1., dtype=torch.complex64, device=device), torch.complex128)
        _test_spot(torch.tensor([1, 1], dtype=torch.bool, device=device), 1., torch.get_default_dtype())

    @float_double_default_dtype
    def test_can_cast(self, device):
        self.assertTrue(torch.can_cast(torch.double, torch.float))
        self.assertFalse(torch.can_cast(torch.float, torch.int))

    @float_double_default_dtype
    def test_comparison_ops_with_type_promotion(self, device):
        value_for_type = {
            torch.uint8: (1 << 5),
            torch.int8: (1 << 5),
            torch.int16: (1 << 10),
            torch.int32: (1 << 20),
            torch.int64: (1 << 35),
            torch.float16: (1 << 10),
            torch.float32: (1 << 20),
            torch.float64: (1 << 35),
            torch.complex64: (1 << 20),
            torch.complex128: (1 << 35)
        }
        comparison_ops = [
            dict(
                name="lt",
                out_op=lambda x, y, d: torch.lt(x, y, out=torch.empty(0, dtype=torch.bool, device=d)),
                ret_op=lambda x, y: torch.lt(x, y),
                compare_op=operator.lt,
            ),
            dict(
                name="le",
                out_op=lambda x, y, d: torch.le(x, y, out=torch.empty(0, dtype=torch.bool, device=d)),
                ret_op=lambda x, y: torch.le(x, y),
                compare_op=operator.le,
            ),
            dict(
                name="gt",
                out_op=lambda x, y, d: torch.gt(x, y, out=torch.empty(0, dtype=torch.bool, device=d)),
                ret_op=lambda x, y: torch.gt(x, y),
                compare_op=operator.gt,
            ),
            dict(
                name="ge",
                out_op=lambda x, y, d: torch.ge(x, y, out=torch.empty(0, dtype=torch.bool, device=d)),
                ret_op=lambda x, y: torch.ge(x, y),
                compare_op=operator.ge,
            ),
            dict(
                name="eq",
                out_op=lambda x, y, d: torch.eq(x, y, out=torch.empty(0, dtype=torch.bool, device=d)),
                ret_op=lambda x, y: torch.eq(x, y),
                compare_op=operator.eq,
            ),
            dict(
                name="ne",
                out_op=lambda x, y, d: torch.ne(x, y, out=torch.empty(0, dtype=torch.bool, device=d)),
                ret_op=lambda x, y: torch.ne(x, y),
                compare_op=operator.ne,
            ),
        ]
        for op in comparison_ops:
            for dt1 in get_all_math_dtypes(device):
                for dt2 in get_all_math_dtypes(device):
                    if (dt1.is_complex or dt2.is_complex) and not (op["name"] == "eq" or op["name"] == "ne"):
                        continue
                    val1 = value_for_type[dt1]
                    val2 = value_for_type[dt2]
                    t1 = torch.tensor([val1], dtype=dt1, device=device)
                    t2 = torch.tensor([val2], dtype=dt2, device=device)
                    expected = torch.tensor([op["compare_op"](val1, val2)], dtype=torch.bool)

                    out_res = op["out_op"](t1, t2, device)
                    self.assertEqual(out_res, expected)
                    self.assertTrue(out_res.dtype == torch.bool)
                    self.assertTrue(t1.dtype == dt1)
                    self.assertTrue(t2.dtype == dt2)

                    out_res = op["ret_op"](t1, t2)
                    self.assertEqual(out_res, expected)
                    self.assertTrue(out_res.dtype == torch.bool)
                    self.assertTrue(t1.dtype == dt1)
                    self.assertTrue(t2.dtype == dt2)

                    # test that comparing a zero dim tensor with another zero dim tensor has type promotion behavior
                    t1 = torch.tensor(val1, dtype=dt1, device=device)
                    t2 = torch.tensor(val2, dtype=dt2, device=device)
                    expected = torch.tensor(op["compare_op"](val1, val2), dtype=torch.bool)

                    out_res = op["out_op"](t1, t2, device)
                    self.assertEqual(out_res, expected)
                    self.assertTrue(out_res.dtype == torch.bool)
                    self.assertTrue(t1.dtype == dt1)
                    self.assertTrue(t2.dtype == dt2)

                    out_res = op["ret_op"](t1, t2)
                    self.assertEqual(out_res, expected)
                    self.assertTrue(out_res.dtype == torch.bool)
                    self.assertTrue(t1.dtype == dt1)
                    self.assertTrue(t2.dtype == dt2)

    # XLA tests fail for self.assertRaises for complex dtypes
    @onlyNativeDeviceTypes
    def test_complex_assertraises(self, device):
        comparison_ops = [
            dict(name="lt", compare_op=operator.lt, ),
            dict(name="le", compare_op=operator.le, ),
            dict(name="gt", compare_op=operator.gt, ),
            dict(name="ge", compare_op=operator.ge, ),
            dict(name="eq", compare_op=operator.eq, ),
            dict(name="ne", compare_op=operator.ne, ),
        ]
        for op in comparison_ops:
            is_cuda = torch.device(device).type == 'cuda'
            dtypes = get_all_dtypes(include_half=is_cuda,
                                    include_bfloat16=False, include_bool=False,
                                    include_complex32=True)

            for dt1, dt2 in itertools.product(dtypes, dtypes):
                if (dt1.is_complex or dt2.is_complex) and not (op["name"] == "eq" or op["name"] == "ne"):
                    u = torch.tensor([1], dtype=dt1, device=device)
                    v = torch.tensor([2], dtype=dt2, device=device)
                    self.assertRaises(RuntimeError, lambda: torch.tensor([op["compare_op"](u, v)], dtype=torch.bool))

    @float_double_default_dtype
    def test_lt_with_type_promotion(self, device):
        for dt in get_all_math_dtypes(device):
            x = torch.tensor([0], dtype=dt, device=device)
            expected = torch.tensor([True], dtype=torch.bool, device=device)

            if dt.is_complex:
                continue

            actual = x < 0.5
            self.assertTrue(actual, expected)
            self.assertTrue(actual.dtype == torch.bool)

            actual = x < torch.tensor(0.5, device=device)
            self.assertTrue(actual, expected)
            self.assertTrue(actual.dtype == torch.bool)

            x = torch.tensor(0, dtype=dt, device=device)
            expected = torch.tensor(True, dtype=torch.bool, device=device)
            actual = x < 0.5
            self.assertTrue(actual, expected)
            self.assertTrue(actual.dtype == torch.bool)

            actual = x < torch.tensor(0.5, device=device)
            self.assertTrue(actual, expected)
            self.assertTrue(actual.dtype == torch.bool)

    @float_double_default_dtype
    def test_promote_types(self, device):
        self.assertEqual(torch.promote_types(torch.float, torch.int), torch.float)
        self.assertEqual(torch.promote_types(torch.float, torch.double), torch.double)
        self.assertEqual(torch.promote_types(torch.int, torch.uint8), torch.int)
        with self.assertRaisesRegex(RuntimeError, "Promotion for Float8 Types is not supported"):
            self.assertEqual(torch.promote_types(torch.float8_e5m2, torch.float), torch.float)
        with self.assertRaisesRegex(RuntimeError, "Promotion for Float8 Types is not supported"):
            self.assertEqual(torch.promote_types(torch.float, torch.float8_e4m3fn), torch.float)

    @float_double_default_dtype
    def test_promote_self(self, device):
        for dtype in all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf, torch.bool,
                                               torch.float8_e5m2, torch.float8_e4m3fn):
            self.assertEqual(torch.promote_types(dtype, dtype), dtype)

    @expectedFailureMeta
    @float_double_default_dtype
    def test_indexing_fail(self, device):
        # https://github.com/pytorch/pytorch/issues/28010
        a = torch.ones(5, 2, dtype=torch.double, device=device)
        b = torch.zeros(5, dtype=torch.int, device=device)
        with self.assertRaises(RuntimeError):
            a[:, [1]] = b.unsqueeze(-1)

    @float_double_default_dtype
    def test_indexing(self, device):
        x = torch.ones(5, 2, dtype=torch.double, device=device)
        y = torch.zeros(5, dtype=torch.double, device=device)
        x[:, [1]] = y.unsqueeze(-1)
        expected = torch.tensor([(1, 0), (1, 0), (1, 0), (1, 0), (1, 0)], dtype=torch.double, device=device)
        self.assertEqual(x, expected)


        # https://github.com/pytorch/pytorch/issues/27824
        tmp = torch.ones(9, 9, dtype=torch.float, device=device)
        mask = torch.ones(10, 10, dtype=torch.uint8, device=device)
        result = tmp + mask[1:, 1:]
        expected = torch.full([9, 9], 2., dtype=torch.float, device=device).fill_(2.)
        self.assertEqual(result, expected)

    @float_double_default_dtype
    def test_transpose(self, device):
        # https://github.com/pytorch/pytorch/issues/28502
        a = torch.tensor([[True, True], [False, True]], device=device)
        self.assertEqual(a.t() == 0, a.t() == False)  # noqa: E712

    @dtypes(torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
    @float_double_default_dtype
    def test_div_promotion(self, device, dtype):
        for op in (torch.div, torch.true_divide):
            dividend = (torch.randn(5, device=device) * 100).to(dtype)
            divisor = torch.arange(1, 6, device=device).to(dtype)

            # Tests tensor/tensor division
            casting_result = dividend.to(torch.get_default_dtype()) / divisor.to(torch.get_default_dtype())
            self.assertEqual(casting_result, op(dividend, divisor))

            # Tests tensor/scalar division
            casting_result = dividend.to(torch.get_default_dtype()) / 2
            self.assertEqual(casting_result, op(dividend, 2.))

    @onlyNativeDeviceTypes
    @dtypes(torch.float, torch.double,
            torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
    def test_div_promotion_out(self, device, dtype):
        for op in (torch.div, torch.true_divide):
            dividend = (torch.randn(5, device=device) * 100).to(dtype)
            divisor = torch.arange(1, 6, device=device).to(dtype)

            # Tests that requests for an integer quotient fail
            if not dtype.is_floating_point:
                integral_quotient = torch.empty(5, device=device, dtype=dtype)
                with self.assertRaises(RuntimeError):
                    op(dividend, divisor, out=integral_quotient)
                with self.assertRaises(RuntimeError):
                    op(dividend, 2, out=integral_quotient)
            else:
                # Tests that requests for a floating quotient succeed
                floating_quotient = torch.empty(5, device=device, dtype=dtype)
                div_result = dividend / divisor
                self.assertEqual(div_result,
                                 op(dividend, divisor, out=floating_quotient))
                self.assertEqual(dividend / 2,
                                 op(dividend, 2, out=floating_quotient))

    @dtypes(torch.float, torch.double,
            torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
    def test_div_promotion_inplace(self, device, dtype):
        for op in (torch.Tensor.div_, torch.Tensor.true_divide_):
            dividend = (torch.randn(5, device=device) * 100).to(dtype)
            divisor = torch.arange(1, 6, device=device).to(dtype)

            # Tests that requests for an integer quotient fail
            if not dtype.is_floating_point:
                with self.assertRaises(RuntimeError):
                    op(dividend, divisor)
                with self.assertRaises(RuntimeError):
                    op(dividend, 2)
            else:
                # Tests that requests for a floating quotient succeed
                div_result = dividend.clone().div_(divisor)
                self.assertEqual(div_result, op(dividend.clone(), divisor))
                self.assertEqual(dividend.clone().div_(2), op(dividend.clone(), 2))

    def _test_sparse_op_input_tensors(self, device, dtype, coalesced, zeros=True):
        t = self._get_test_tensor(device, dtype, not zeros)
        if zeros and dtype != torch.bool:
            # ensure sparsity. Bool should already have sufficient sparsity.
            mask = self._get_test_tensor(device, torch.bool)
            t = t * mask

        if coalesced:
            s = t.to_sparse()
        else:
            s = t.to_sparse()
            indices = torch.cat((s.indices(), s.indices()), 1)
            values = torch.cat((s.values(), s.values()), 0)
            s = torch.sparse_coo_tensor(indices=indices, values=values, size=s.size(), dtype=dtype, device=device)
            t = s.to_dense()
        self.assertEqual(s.is_coalesced(), coalesced)
        self.assertEqual(s.dtype, dtype)
        self.assertEqual(t.dtype, s.dtype)
        return t, s

    def _get_precision(self, dtype, coalesced):
        if dtype == torch.half and not coalesced:
            # very low precision for uncoalesced float16 sparse tensors since
            # ops like (s1 + s2).to_dense() will add four low-precision
            # floating point values.
            return 5e-2
        if dtype == torch.half:
            return 1e-3
        # uses default
        return None

    def _test_sparse_op(self, op_name, inplace, dtype1, dtype2, device, coalesced):
        if dtype1.is_complex or dtype2.is_complex:
            return

        suffix = '_' if inplace else ''
        err = f"{'  coalesced' if coalesced else 'uncoalesced'} {op_name + suffix}({dtype1}, {dtype2})"

        def op(t1, t2, suf=None):
            suf = suffix if suf is None else suf
            return getattr(t1, op_name + suf)(t2)

        add_sub = op_name == 'add' or op_name == 'sub'

        (dense1, sparse1) = self._test_sparse_op_input_tensors(device, dtype1, coalesced)
        (dense2, sparse2) = self._test_sparse_op_input_tensors(device, dtype2, coalesced, op_name != 'div')

        common_dtype = torch.result_type(dense1, dense2)
        if self.device_type == 'cpu' and common_dtype == torch.half:
            self.assertRaises(RuntimeError, lambda: op(sparse1, dense2))

        # Skip inplace tests that would fail due to inability to cast to the output type.
        # Some of these would also raise errors due to not being a supported op.
        if inplace and not torch.can_cast(common_dtype, dtype1):
            self.assertRaises(RuntimeError, lambda: op(dense1, sparse2))
            self.assertRaises(RuntimeError, lambda: op(sparse1, sparse2))
            self.assertRaises(RuntimeError, lambda: op(sparse1, dense2))
            return

        expected = op(dense1.clone(), dense2)
        precision = self._get_precision(expected.dtype, coalesced)
        rtol = None if precision is None else 0
        test_tensors = [expected, dense1, sparse1, dense2, sparse2]
        e, d1, s1, d2, s2 = [x.clone() for x in test_tensors] if inplace else test_tensors

        # Test op(sparse, sparse)
        if op_name != 'div':
            sparse = op(s1, s2)
            self.assertEqual(sparse.dtype, e.dtype)
            self.assertEqual(e, sparse.to_dense(), atol=precision, rtol=rtol, msg=err)
        else:
            # sparse division only supports division by a scalar
            self.assertRaises(RuntimeError, lambda: op(s1, s2).to_dense())

        # Test op(dense, sparse)
        if add_sub or op_name == 'mul':
            if inplace:
                e, d1, s1, d2, s2 = (x.clone() for x in test_tensors)
            dense_sparse = op(d1, s2)
            dense_sparse = dense_sparse.to_dense() if dense_sparse.is_sparse else dense_sparse
            self.assertEqual(e, dense_sparse, atol=precision, rtol=rtol, msg=err)
        else:
            # sparse division only supports division by a scalar
            # mul: Didn't find kernel to dispatch to for operator 'aten::_nnz'
            self.assertRaises(RuntimeError, lambda: op(d1, s2))

        # Test op(sparse, dense) not supported for all ops but 'mul'.
        # add(sparse, dense) is not supported. Use add(dense, sparse) instead.
        # sparse division only supports division by a scalar
        if op_name != 'mul':
            self.assertRaises(RuntimeError, lambda: op(s1, d2))
        else:
            # No type promotions for inplace operations, hence suf=''
            op(s1, d2, suf='')

        # Test op(sparse, scalar)
        if not add_sub and not (self.device_type == 'cpu' and dtype1 == torch.half):
            if inplace:
                e, d1, s1, d2, s2 = (x.clone() for x in test_tensors)
            scalar = d2.view(d2.numel())[0].item()

            sparse = op(s1, scalar)
            dense_scalar = op(d1, scalar)
            self.assertEqual(sparse.dtype, dense_scalar.dtype)
            self.assertEqual(dense_scalar, sparse.to_dense(), atol=precision, rtol=rtol, msg=err)
        else:
            # add(sparse, dense) is not supported. Use add(dense, sparse) instead.
            # "mul_cpu" / "div_cpu" not implemented for 'Half'
            self.assertRaises(RuntimeError, lambda: op(s1, d2.view(d2.numel())[0].item()))

    def _run_all_tests_for_sparse_op(self, op_name, device, dtypes):
        for dtype1, dtype2 in itertools.product(dtypes, dtypes):
            for inplace, coalesced in itertools.product([True, False], [True, False]):
                self._test_sparse_op(op_name, inplace, dtype1, dtype2, device, coalesced)

    @onlyNativeDeviceTypes
    def test_sparse_add(self, device):
        self._run_all_tests_for_sparse_op('add', device,
                                          dtypes=get_all_math_dtypes(device))

    @onlyNativeDeviceTypes
    def test_sparse_mul(self, device):
        self._run_all_tests_for_sparse_op('mul', device,
                                          dtypes=get_all_math_dtypes(device))

    @onlyNativeDeviceTypes
    def test_sparse_div(self, device):
        self._run_all_tests_for_sparse_op('div', device,
                                          dtypes=(torch.float32, torch.float64,
                                                  torch.complex64, torch.complex128))

    @onlyNativeDeviceTypes
    def test_sparse_sub(self, device):
        self._run_all_tests_for_sparse_op('sub', device,
                                          dtypes=get_all_math_dtypes(device))

    @onlyNativeDeviceTypes
    @dtypes(torch.bool, torch.short, torch.uint8, torch.int, torch.long)
    @float_double_default_dtype
    def test_sparse_div_promotion(self, device, dtype):
        for op in (torch.div, torch.true_divide):
            dividend = torch.randn(5, device=device).to(dtype)
            divisor = 2
            dividend_sparse = dividend.to_sparse()
            casting_result = dividend.to(torch.get_default_dtype()) / 2
            self.assertEqual(casting_result, op(dividend_sparse, 2).to_dense())

    @onlyNativeDeviceTypes
    @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)
    def test_integer_addcdiv_deprecated(self, device, dtype):
        t = torch.tensor(1, device=device, dtype=dtype)

        with self.assertRaisesRegex(RuntimeError, '^Integer division.+is no longer supported.+'):
            torch.addcdiv(t, t, t)
        with self.assertRaisesRegex(RuntimeError, '^Integer division.+is no longer supported.+'):
            torch.addcdiv(t, t, t, out=t)
        with self.assertRaisesRegex(RuntimeError, '^Integer division.+is no longer supported.+'):
            t.addcdiv_(t, t)

    @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
    @float_double_default_dtype
    @onlyCPU
    # NB: skip uint16,32,64 as PyTorch doesn't implement promotion for them
    @dtypes(*list(itertools.product(
        set(numpy_to_torch_dtype_dict.values()) - {torch.uint16, torch.uint32, torch.uint64},
        set(numpy_to_torch_dtype_dict.values()) - {torch.uint16, torch.uint32, torch.uint64})))
    def test_numpy_array_binary_ufunc_promotion(self, device, dtypes):
        import operator
        np_type = torch_to_numpy_dtype_dict[dtypes[0]]
        torch_type = dtypes[1]

        t = torch.tensor((1,), device=device, dtype=torch_type)
        a = np.array((1,), dtype=np_type)
        a_as_t = torch.from_numpy(a).to(device=device)

        for np_first in (True, False):
            for op in (operator.add, torch.add):

                # Acquires results of binary ufunc type promotion.
                try:
                    actual = op(a, t) if np_first else op(t, a)
                except Exception as e:
                    actual = e

                try:
                    expected = op(a_as_t, t) if np_first else op(t, a_as_t)
                except Exception as e:
                    expected = e

                same_result = (type(expected) == type(actual)) and expected == actual

                # Note: An "undesired failure," as opposed to an "expected failure"
                # is both expected (we know the test will fail) and
                # undesirable (if PyTorch was working properly the test would
                # not fail). This test is affected by three issues (see below)
                # that will cause undesired failures. It detects when these
                # issues will occur and updates this bool accordingly.
                undesired_failure = False

                # A NumPy array as the first argument to the plus operator
                # or as any argument to torch.add is not working as
                # intended.
                # See https://github.com/pytorch/pytorch/issues/36363.
                if np_first and op is operator.add:
                    undesired_failure = True
                if op is torch.add:
                    undesired_failure = True

                # Expects the same result if undesired_failure is false
                # and a different result otherwise.
                # Note: These cases prettyprint the failing inputs to make
                # debugging test failures easier.
                if undesired_failure and same_result:
                    msg = (
                        f"Failure: {actual} == {expected}. torch type was {torch_type}. "
                        f"NumPy type was {np_type}. np_first is {np_first} default type is "
                        f"{torch.get_default_dtype()}."
                    )
                    self.fail(msg)

                if not undesired_failure and not same_result:
                    msg = (
                        f"Failure: {actual} != {expected}. torch type was {torch_type}. "
                        f"NumPy type was {np_type}. np_first is {np_first} default type is "
                        f"{torch.get_default_dtype()}."
                    )
                    self.fail(msg)


    @onlyNativeDeviceTypes
    def test_cat_different_dtypes(self, device):
        dtypes = all_types_and_complex_and(torch.half, torch.bool)
        for x_dtype, y_dtype in itertools.product(dtypes, dtypes):
            x_vals, y_vals = [1, 2, 3], [4, 5, 6]

            x = torch.tensor(x_vals, device=device, dtype=x_dtype)
            y = torch.tensor(y_vals, device=device, dtype=y_dtype)

            if x_dtype is torch.bool:
                x_vals = [1, 1, 1]
            if y_dtype is torch.bool:
                y_vals = [1, 1, 1]

            res_dtype = torch.result_type(x, y)
            expected_res = torch.tensor(x_vals + y_vals, device=device, dtype=res_dtype)
            res = torch.cat([x, y])
            self.assertEqual(res, expected_res, exact_dtype=True)

            # cat: full and an empty tensor.
            y = torch.tensor([], device=device, dtype=y_dtype)
            res_dtype = torch.result_type(x, y)
            expected_res = torch.tensor(x_vals + [], device=device, dtype=res_dtype)
            res = torch.cat([x, y])
            self.assertEqual(res, expected_res, exact_dtype=True)

    @onlyNativeDeviceTypes
    def test_cat_out_different_dtypes(self, device):
        dtypes = all_types_and_complex_and(torch.half)
        for x_dtype, y_dtype, out_dtype in itertools.product(dtypes, dtypes, dtypes):
            out = torch.zeros(6, device=device, dtype=out_dtype)
            x = torch.tensor([1, 2, 3], device=device, dtype=x_dtype)
            y = torch.tensor([4, 5, 6], device=device, dtype=y_dtype)
            expected_out = torch.tensor([1, 2, 3, 4, 5, 6], device=device, dtype=out_dtype)
            if (((x_dtype.is_floating_point or y_dtype.is_floating_point)
                    and not (out_dtype.is_floating_point or out_dtype.is_complex))
                    or ((x_dtype.is_complex or y_dtype.is_complex) and not out_dtype.is_complex)):
                # These combinations do not support type conversion to an out dtype of a different class
                with self.assertRaises(RuntimeError):
                    torch.cat([x, y], out=out)
            else:
                torch.cat([x, y], out=out)
                self.assertEqual(out, expected_out, exact_dtype=True)

    # Verifies that unary ops require matching out types
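    # A hedged sketch of what "matching" means here (illustrative only):
    #   torch.neg(torch.ones(3), out=torch.empty(3, dtype=torch.int64))  # RuntimeError: mismatched out dtype
    #   torch.neg(torch.ones(3), out=torch.empty(3))                     # ok, dtypes match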
    @onlyNativeDeviceTypes
    @dtypes(*itertools.product((torch.int64,
                                torch.float32, torch.float64,
                                torch.complex64, torch.complex128),
                               (torch.int64,
                                torch.float32, torch.float64,
                                torch.complex64, torch.complex128)))
    def test_unary_op_out_casting(self, device, dtypes):
        t = torch.tensor((1), dtype=dtypes[0], device=device)
        out = torch.empty(0, dtype=dtypes[1], device=device)

        ops = (torch.neg, torch.floor, torch.ceil)
        float_and_int_only_ops = {torch.floor, torch.ceil}
        real_only_ops = {torch.floor, torch.ceil}
        for op in ops:
            if dtypes[0] is not dtypes[1]:
                with self.assertRaises(RuntimeError):
                    op(t, out=out)
            elif op in real_only_ops and dtypes[0].is_complex:
                with self.assertRaises(RuntimeError):
                    op(t, out=out)
            elif (
                    op in float_and_int_only_ops
                    and (not dtypes[0].is_floating_point and not dtypes[0].is_complex)
                    and (not (dtypes[0] == torch.int64 and dtypes[1] == torch.int64))
                    and device != "meta"
            ):
                with self.assertRaises(RuntimeError):
                    op(t, out=out)
            else:
                self.assertEqual(op(t, out=out), op(t))
                self.assertEqual(op(t, out=out), out)

    # Verifies that the out= argument doesn't affect the computation, that
    # is, out = op(...) and op(..., out=out) produce the same result.
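    # Hedged illustration: the values are computed as if no out= were given and only then
    # cast into out, e.g.
    #   t = torch.tensor(33000, dtype=torch.float16)
    #   torch.add(t, t, out=torch.empty(0, dtype=torch.float64))
    #   # -> inf, because the fp16 addition overflows even though float64 could hold 66000.0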
    @onlyNativeDeviceTypes
    @skipMeta
    def test_computation_ignores_out(self, device):
        t = torch.tensor(33000, dtype=torch.float16, device=device)
        out = torch.empty(0, dtype=torch.float64, device=device)
        result = torch.add(t, t, out=out)
        self.assertEqual(result, t + t, exact_dtype=False)
        self.assertNotEqual(result, t.double() + t, exact_dtype=False)

        a = torch.tensor(1.5, dtype=torch.float16, device=device)
        b = torch.tensor(.666, dtype=torch.float16, device=device)
        result = torch.true_divide(a, b, out=out)
        self.assertEqual(result, a / b, exact_dtype=False)
        self.assertNotEqual(result, a.double() / b, exact_dtype=False)

        a = torch.tensor(5, dtype=torch.uint8, device=device)
        b = torch.tensor(8, dtype=torch.uint8, device=device)
        result = torch.sub(a, b, out=out)
        self.assertEqual(result, a - b, exact_dtype=False)
        self.assertNotEqual(result, a.double() - b, exact_dtype=False)

    @onlyNativeDeviceTypes
    @dtypes(*itertools.product((torch.bool, torch.int, torch.float, torch.double), repeat=3))
    def test_clamp_type_promotion(self, device, dtypes):
        dtype0, dtype1, dtype2 = dtypes
        S = 4

        def make_tensor(size, dtype):
            if dtype == torch.bool:
                return torch.randint(2, size, dtype=dtype, device=device)
            elif dtype == torch.int:
                return torch.randint(10, size, dtype=dtype, device=device)
            else:
                return torch.randn(size, dtype=dtype, device=device)
        min_t = make_tensor((S,), dtype1)
        max_t = make_tensor((S,), dtype2)
        mins = (min_t, min_t[0], min_t[0].item())
        maxs = (max_t, max_t[0], max_t[0].item())
        inp = make_tensor((S,), dtype0)
        for min_v, max_v in itertools.product(mins, maxs):
            if type(max_v) != type(min_v):
                continue
            if isinstance(min_v, torch.Tensor) and min_v.ndim == 0 and max_v.ndim == 0:
                continue  # 0d tensors go to scalar overload, and it's tested separately

            def expected_type(inp, max, min):
                arg1, arg2 = max, min
                if isinstance(max, torch.Tensor) and max.ndim == 0:
                    # first do a maybe dimensional boundary
                    arg1, arg2 = min, max
                exp_type = torch.result_type(inp, arg1)
                inp_new = torch.empty_like(inp, dtype=exp_type)
                return torch.result_type(inp_new, arg2)
            exp_type = expected_type(inp, min_v, max_v)
            if exp_type != torch.bool:
                actual = torch.clamp(inp, min_v, max_v)
                inps = [x.to(exp_type) if isinstance(x, torch.Tensor) else x for x in (inp, min_v, max_v)]
                expected = torch.clamp(inps[0], inps[1], inps[2])
                self.assertEqual(actual, expected)
                if inp.dtype in floating_types() or exp_type == inp.dtype:
                    actual = torch.clamp_(inp, min_v, max_v)
                    self.assertEqual(actual, expected, exact_dtype=False)
        for val in mins:
            def expected_type(inp, val):
                return torch.result_type(inp, val)
            exp_type = expected_type(inp, val)
            if exp_type != torch.bool:
                actual = torch.clamp_min(inp, val)
                inps = [x.to(exp_type) if isinstance(x, torch.Tensor) else x for x in (inp, val)]
                expected = torch.clamp_min(inps[0], inps[1])
                self.assertEqual(actual.dtype, exp_type)
                self.assertEqual(actual, expected)
                if inp.dtype == exp_type:
                    actual = torch.clamp_min_(inp, val)
                    self.assertEqual(actual, expected)
                actual = torch.clamp_max(inp, val)
                expected = torch.clamp_max(inps[0], inps[1])
                self.assertEqual(actual, expected)
                if inp.dtype in floating_types() or exp_type == inp.dtype:
                    actual = torch.clamp_max_(inp, val)
                    self.assertEqual(actual, expected, exact_dtype=False)

    @onlyNativeDeviceTypes
    def test_ternary_out_promotion(self, device):
        for op in [torch.addcdiv, torch.addcmul]:
            for dtype in [torch.float32, torch.cfloat]:
                prom_dtype = torch.float64 if dtype is torch.float32 else torch.cdouble if dtype is torch.cfloat else dtype
                x = torch.rand(3, device=device, dtype=dtype)
                y = torch.empty(3, device=device, dtype=dtype)
                y_promo = torch.empty(3, device=device, dtype=prom_dtype)
                op(x, x, x, out=y)
                op(x, x, x, out=y_promo)
                self.assertEqual(y, y_promo.to(dtype=dtype))


instantiate_device_type_tests(TestTypePromotion, globals())

if __name__ == '__main__':
    run_tests()