1# Owner(s): ["module: type promotion"] 2 3from functools import wraps 4import itertools 5import unittest 6 7import torch 8 9from torch.testing._internal.common_utils import (TestCase, run_tests, load_tests, make_tensor, 10 TEST_NUMPY, set_default_dtype, torch_to_numpy_dtype_dict, 11 numpy_to_torch_dtype_dict, skipIfTorchDynamo, 12 xfailIfTorchDynamo) 13from torch.testing._internal.common_device_type import (instantiate_device_type_tests, onlyNativeDeviceTypes, 14 dtypes, onlyCPU, expectedFailureMeta, skipMeta) 15from torch.testing._internal.common_dtype import ( 16 all_types_and_complex_and, get_all_math_dtypes, floating_types, get_all_dtypes, 17 float_to_corresponding_complex_type_map, 18) 19 20 21import numpy as np 22import operator 23 24# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for 25# sharding on sandcastle. This line silences flake warnings 26load_tests = load_tests 27 28# Not thread-safe decorator that runs the decorated test once with 29# the default dtype being torch.float and again with the default dtype 30# being torch.double. 31def float_double_default_dtype(fn): 32 @wraps(fn) 33 def wrapped_fn(*args, **kwargs): 34 with set_default_dtype(torch.float): 35 fn(*args, **kwargs) 36 with set_default_dtype(torch.double): 37 fn(*args, **kwargs) 38 39 return wrapped_fn 40 41class TestTypePromotion(TestCase): 42 43 # In-place operations don't promote. 44 # `int+float -> float` but `int.add_(float)` is rejected as an error. 45 # Promoting inplace would require re-allocating and copying the memory of the 46 # tensor data, since element size could change. 47 # https://github.com/pytorch/pytorch/issues/127049 48 @xfailIfTorchDynamo 49 @float_double_default_dtype 50 def test_inplace(self, device): 51 int_tensor = torch.ones([4, 4, 4], dtype=torch.int32, device=device) 52 53 self.assertRaisesRegex(RuntimeError, "can't be cast to", lambda: int_tensor.add_(1.5)) 54 55 expected = torch.ones([4, 4, 4], dtype=torch.int32, device=device) 56 57 long_tensor = torch.ones([4, 4, 4], dtype=torch.int64, device=device) 58 int_tensor.add_(long_tensor) 59 int_tensor.add_(1) 60 three = expected + 2 61 self.assertEqual(int_tensor, three) 62 self.assertEqual(int_tensor.dtype, torch.int32) 63 64 bool_tensor = torch.tensor([1, 1, 1], dtype=torch.bool, device=device) 65 uint8_tensor = torch.tensor([1, 1, 1], dtype=torch.uint8, device=device) 66 # We treat bool as a separate category, which means uint8 cannot cast to bool. 67 self.assertRaisesRegex(RuntimeError, "can't be cast to", lambda: bool_tensor.add_(uint8_tensor)) 68 69 # We allow demotion from signed to unsigned, unlike numpy, because: 70 # * We don't want the performance penalty of inspecting scalar values. 71 # * We don't want 'signed' to be considered a distinct 'category' 72 # in promotion rules. 73 # We don't want signed to be a separate category because if it was, 74 # uint16_tensor + 5 would result in a long_tensor, which is not what we want. 75 int16_tensor = torch.tensor([1, 1, 1], dtype=torch.int16, device=device) 76 uint8_tensor *= int16_tensor 77 78 @float_double_default_dtype 79 def test_unsigned(self, device): 80 dont_promote = torch.ones(3, dtype=torch.uint8, device=device) + 5 81 self.assertEqual(dont_promote.dtype, torch.uint8) 82 83 # some basic examples 84 85 @float_double_default_dtype 86 def test_int_promotion(self, device): 87 a = torch.ones([4, 4, 4], dtype=torch.int32, device=device) 88 b = torch.ones([4, 4, 4], dtype=torch.int64, device=device) 89 c = a + b 90 self.assertEqual(c, b + b) 91 self.assertEqual(c.dtype, torch.int64) 92 93 @float_double_default_dtype 94 def test_float_promotion(self, device): 95 def test_promotion(dtype_float, dtype_double): 96 a = torch.ones([4, 4, 4], dtype=dtype_float, device=device) 97 b = torch.ones([4, 4, 4], dtype=dtype_double, device=device) 98 c = a + b 99 self.assertEqual(c, b + b) 100 self.assertEqual(c.dtype, dtype_double) 101 c = b + a 102 self.assertEqual(c, b + b) 103 self.assertEqual(c.dtype, dtype_double) 104 test_promotion(torch.float, torch.double) 105 106 @float_double_default_dtype 107 def test_complex_promotion(self, device): 108 def test_promotion(dtype_float, dtype_double): 109 a = torch.ones([4, 4, 4], dtype=dtype_float, device=device) 110 b = torch.ones([4, 4, 4], dtype=dtype_double, device=device) 111 c = a + b 112 self.assertEqual(c, b + b) 113 self.assertEqual(c.dtype, dtype_double) 114 c = b + a 115 self.assertEqual(c, b + b) 116 self.assertEqual(c.dtype, dtype_double) 117 118 test_promotion(torch.complex64, torch.complex128) 119 120 a = torch.randn(3, dtype=torch.complex64, device=device) 121 self.assertEqual((a * 5).dtype, torch.complex64) 122 # not a "wrapped number" 123 other = torch.tensor(5.5, dtype=torch.double, device=device) 124 self.assertEqual((a + other).dtype, torch.complex64) 125 126 def make_scalar_tensor(dtype): 127 return make_tensor((), dtype=dtype, device=device) 128 129 def make_1d_tensor(dtype): 130 return make_tensor((3,), dtype=dtype, device=device) 131 132 def complex_scalar_tensor_test(s, t): 133 # As per type promotion rules, 134 # Complex Scalar and Float Tensor -> Complex Tensor with Value type of Float Tensor 135 # Complex Scalar and Integral Tensor -> Complex Tensor with Value type of Complex Scalar 136 137 if t.dtype.is_floating_point: 138 # defaults to return complex64 (for bfloat16) 139 expected_dtype = float_to_corresponding_complex_type_map.get(t.dtype, torch.complex64) 140 else: # integral tensor 141 if isinstance(s, torch.Tensor): 142 expected_dtype = s.dtype 143 else: 144 expected_dtype = float_to_corresponding_complex_type_map[torch.get_default_dtype()] 145 self.assertEqual((s * t).dtype, expected_dtype) 146 self.assertEqual((t * s).dtype, expected_dtype) 147 self.assertEqual(torch.result_type(s, t), expected_dtype) 148 self.assertEqual(torch.result_type(t, s), expected_dtype) 149 150 if torch.device(device).type != 'xla': 151 # chalf is not supported on XLA 152 s = make_scalar_tensor(dtype=torch.chalf) 153 # Same Value type 154 t = make_1d_tensor(dtype=torch.half) 155 # 0-D Tensor X 1-D Tensor 156 complex_scalar_tensor_test(s, t) 157 # Python Scalar X 1-D Tensor 158 complex_scalar_tensor_test(s.item(), t) 159 160 # Higher Value Type 161 t = make_1d_tensor(dtype=torch.float) 162 complex_scalar_tensor_test(s, t) 163 complex_scalar_tensor_test(s.item(), t) 164 165 # Special Case 166 t = make_1d_tensor(dtype=torch.bfloat16) 167 complex_scalar_tensor_test(s, t) 168 complex_scalar_tensor_test(s.item(), t) 169 170 # Integral Tensor 171 t = make_1d_tensor(dtype=torch.long) 172 complex_scalar_tensor_test(s, t) 173 complex_scalar_tensor_test(s.item(), t) 174 175 # CFloat Scalar 176 s = make_scalar_tensor(dtype=torch.cfloat) 177 # Lower Value type than CFloat 178 t = make_1d_tensor(dtype=torch.half) 179 complex_scalar_tensor_test(s, t) 180 complex_scalar_tensor_test(s.item(), t) 181 182 # Higher Value type than CFloat 183 t = make_1d_tensor(dtype=torch.double) 184 complex_scalar_tensor_test(s, t) 185 complex_scalar_tensor_test(s.item(), t) 186 187 # Integral Tensor 188 t = make_1d_tensor(dtype=torch.long) 189 # 0-D Tensor X 1-D Tensor 190 complex_scalar_tensor_test(s, t) 191 # Python Scalar X 1-D Tensor 192 complex_scalar_tensor_test(s.item(), t) 193 194 # CDouble Scalar 195 s = make_scalar_tensor(dtype=torch.cdouble) 196 197 # Lower Value type than CDouble 198 t = make_1d_tensor(dtype=torch.float) 199 complex_scalar_tensor_test(s, t) 200 complex_scalar_tensor_test(s.item(), t) 201 202 # Special Case 203 t = make_1d_tensor(dtype=torch.bfloat16) 204 complex_scalar_tensor_test(s, t) 205 complex_scalar_tensor_test(s.item(), t) 206 207 @float_double_default_dtype 208 def test_complex_scalar_mult_tensor_promotion(self, device): 209 a = 1j * torch.ones(2, device=device) 210 a = a + 1j 211 b = torch.tensor([2j, 2j], device=device) 212 self.assertEqual(a, b) 213 self.assertEqual(a.dtype, b.dtype) 214 215 @float_double_default_dtype 216 def test_add_wrapped(self, device): 217 a = torch.ones([4, 4, 4], dtype=torch.int, device=device) 218 b = 1 219 c = a + b 220 self.assertEqual(c, a + a) 221 self.assertEqual(c.dtype, torch.int) 222 223 @float_double_default_dtype 224 def test_int_to_float(self, device): 225 a = torch.ones([4, 4, 4], dtype=torch.int32, device=device) 226 b = torch.ones([4, 4, 4], dtype=torch.float, device=device) 227 c = a + b 228 self.assertEqual(c.dtype, torch.float32) 229 230 # some examples from: 231 # https://github.com/pytorch/pytorch/issues/9515 232 233 @float_double_default_dtype 234 def test_from_issue(self, device): 235 a = torch.rand(3, dtype=torch.float32, device=device) 236 u = torch.tensor([0, 0, 1], dtype=torch.uint8, device=device) 237 self.assertEqual((a * 5).dtype, torch.float32) 238 self.assertEqual((u + 1).dtype, torch.uint8) 239 self.assertEqual((u + 1000).dtype, torch.uint8) # integer overflow 240 241 # not a "wrapped number" 242 other = torch.tensor(5.5, dtype=torch.double, device=device) 243 244 self.assertEqual((u + 5.5).dtype, torch.get_default_dtype()) 245 self.assertEqual((u + other).dtype, torch.double) 246 # adding a 0-dim tensor to a float doesn't promote to double unless first 247 # type was integral. 248 self.assertEqual((a + other).dtype, torch.float32) 249 250 @float_double_default_dtype 251 def test_half(self, device): 252 half = torch.tensor(5.5, dtype=torch.float16, device=device) 253 self.assertEqual((half + 2.2).dtype, torch.float16) 254 self.assertEqual((half + 100000).dtype, torch.float16) # inf 255 default_tensor = torch.tensor(100000.0, device=device) 256 self.assertEqual((half + default_tensor).dtype, torch.get_default_dtype()) 257 258 def test_bfloat16(self, device): 259 # with scalar 260 bf = torch.tensor(5.5, dtype=torch.bfloat16, device=device) 261 for scalar in (2.2, 5, 100000): # bf + 100000 is inf 262 self.assertEqual((bf + scalar).dtype, torch.bfloat16) 263 self.assertEqual(scalar + bf, bf + scalar) 264 265 for scalar in (complex(1, 1), complex(-2, 0), complex(0, -3)): 266 self.assertEqual((bf + scalar).dtype, torch.cfloat) 267 self.assertEqual(bf + scalar, scalar + bf) 268 269 # with tensor 270 for dtype in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool): 271 t = torch.tensor(1, dtype=dtype, device=device) 272 self.assertEqual(bf + t, t + bf) 273 if dtype in (torch.float16, torch.float32, torch.float64, torch.cfloat, torch.cdouble): 274 # Handles bfloat16 x float16 -> float32 promotion 275 expected_dtype = dtype if dtype != torch.half else torch.float32 276 elif dtype is torch.chalf: 277 expected_dtype = torch.cfloat 278 elif dtype in (torch.bool, torch.uint8, 279 torch.int8, torch.int16, torch.int32, torch.int64, torch.bfloat16): 280 expected_dtype = torch.bfloat16 281 else: 282 raise AssertionError(f'Missing dtype {dtype} not tested.') 283 284 self.assertEqual(torch.promote_types(dtype, torch.bfloat16), expected_dtype) 285 self.assertEqual(torch.promote_types(torch.bfloat16, dtype), expected_dtype) 286 self.assertEqual((bf + t).dtype, expected_dtype) 287 288 @onlyNativeDeviceTypes 289 def test_complex_half(self, device): 290 # with scalar 291 chalf = torch.tensor(5.5, dtype=torch.chalf, device=device) 292 for scalar in (2.2, 5, 100000): # chalf + 100000 is inf 293 self.assertEqual((chalf * scalar).dtype, torch.chalf) 294 self.assertEqual(scalar * chalf, chalf * scalar) 295 296 for scalar in (complex(1, 1), complex(-2, 0), complex(0, -3)): 297 self.assertEqual((chalf * scalar).dtype, torch.chalf) 298 self.assertEqual(chalf * scalar, scalar * chalf) 299 300 # with tensor 301 dtypes = all_types_and_complex_and(torch.chalf, torch.half, torch.bfloat16, torch.bool) 302 for dtype in dtypes: 303 t = torch.tensor(1, dtype=dtype, device=device) 304 self.assertEqual(chalf * t, t * chalf) 305 if dtype in (torch.float16, torch.chalf): 306 expected_dtype = torch.chalf 307 elif dtype in (torch.float, torch.double, torch.bfloat16): 308 expected_dtype = torch.cdouble if dtype is torch.double else torch.cfloat 309 elif dtype in (torch.cfloat, torch.cdouble): 310 expected_dtype = dtype 311 elif dtype in (torch.bool, torch.uint8, 312 torch.int8, torch.int16, torch.int32, torch.int64): 313 expected_dtype = torch.chalf 314 else: 315 raise AssertionError(f'Missing dtype {dtype} not tested.') 316 317 self.assertEqual(torch.promote_types(dtype, torch.chalf), expected_dtype) 318 self.assertEqual(torch.promote_types(torch.chalf, dtype), expected_dtype) 319 self.assertEqual((chalf * t).dtype, expected_dtype) 320 321 @float_double_default_dtype 322 def test_alternate_result(self, device): 323 x = torch.tensor([1, 1, 1, 1], dtype=torch.float, device=device) 324 o = torch.tensor([0, 0, 0, 0], dtype=torch.long, device=device) 325 self.assertRaisesRegex(RuntimeError, 326 "can't be cast to", 327 lambda: torch.add(x, x, out=o)) 328 d = torch.tensor([1, 1, 1, 1], dtype=torch.double, device=device) 329 torch.add(x, x, out=d) 330 self.assertEqual(d.dtype, torch.double) 331 x = x.to(torch.double) 332 self.assertEqual(x + x, d) 333 334 @float_double_default_dtype 335 def test_mixed_type_backward(self, device): 336 f = torch.ones([3, 3], dtype=torch.float, requires_grad=True, device=device) 337 ten = torch.tensor([10.], dtype=torch.double, device=device) 338 tens = f * ten 339 s = (tens + 2).sum() 340 s.backward() 341 expected = f.grad.to(torch.double) 342 self.assertEqual(tens, expected) 343 344 # If we don't convert the returned grad_input to the actual input type 345 # we get an error like: 346 # RuntimeError: Function SubBackward0 returned an invalid gradient at index 0 - expected type \ 347 # torch.FloatTensor but got torch.DoubleTensor 348 f_dtypes = [torch.float, torch.double] 349 if self.device_type == 'cuda': 350 f_dtypes = f_dtypes + [torch.half] 351 i_dtypes = [torch.int, torch.long] 352 for func in [torch.add, torch.sub, torch.rsub, torch.mul, torch.div]: 353 for dtype1, dtype2 in itertools.product(f_dtypes, f_dtypes + i_dtypes): 354 x = torch.ones(10, requires_grad=True, dtype=dtype1, device=device) 355 y = torch.ones(10, dtype=dtype2, device=device) 356 func(x, y).sum().backward() 357 358 def _get_test_tensor(self, device, dtype, remove_zeros=False): 359 shape = [5, 5, 5] 360 if dtype == torch.bool: 361 tensor = torch.randint(int(remove_zeros), 2, shape, device=device, dtype=dtype) 362 elif dtype.is_floating_point or dtype.is_complex: 363 # "_th_normal_ not supported on CPUType for Half" so simpler create and convert 364 tensor = torch.randn(shape, device=device) 365 tensor = tensor.to(dtype) 366 if remove_zeros: 367 tensor[torch.abs(tensor) < 0.05] = 5 368 else: 369 tensor = torch.randint(-5 if dtype.is_signed else 0, 10, shape, device=device, dtype=dtype) 370 if remove_zeros: 371 tensor[tensor == 0] = 5 372 return tensor 373 374 # verifies that torch.<op>(first, second) is the same as 375 # torch.<op>(first.to(common_dtype), second.to(common_dtype)) in cases where that should hold. 376 @float_double_default_dtype 377 def test_many_promotions(self, device): 378 # Can also include half on CPU in cases where it will be promoted to a 379 # supported dtype 380 dtypes1 = get_all_math_dtypes('cuda') 381 dtypes2 = get_all_math_dtypes(device) 382 ops = [torch.add, torch.sub, torch.mul, torch.div, torch.rsub] 383 for dt1, dt2 in itertools.product(dtypes1, dtypes2): 384 for op, non_contiguous in itertools.product(ops, [True, False]): 385 common_dtype = torch.promote_types(dt1, dt2) 386 if common_dtype == torch.half and self.device_type == 'cpu': 387 continue 388 if op == torch.sub and common_dtype != torch.bool: 389 # Subtraction, the `-` operator, with a bool tensor is not supported. 390 continue 391 first = self._get_test_tensor(device, dt1) 392 second = self._get_test_tensor(device, dt2, op == torch.div) 393 # test ops with non-contiguous tensors 394 if non_contiguous: 395 first = first.transpose(0, 2) 396 second = second.transpose(2, 1) 397 self.assertNotEqual(first.stride(), second.stride(), 398 msg="some non-contiguous issues could be missed if tensors have same strides") 399 400 self.assertEqual(not first.is_contiguous(), non_contiguous) 401 self.assertEqual(not second.is_contiguous(), non_contiguous) 402 result = op(first, second) 403 expected = op(first.to(common_dtype), second.to(common_dtype)) 404 self.assertEqual(result.dtype, expected.dtype, msg=f'{op.__name__} with {dt1}, {dt2}') 405 self.assertEqual(result, expected, msg=f'{op.__name__} with {dt1}, {dt2}') 406 407 @float_double_default_dtype 408 def test_non_promoting_ops(self, device): 409 x = torch.ones(4, dtype=torch.double, device=device) 410 with self.assertRaises(RuntimeError): 411 torch.lerp(x, torch.ones(4, dtype=torch.float, device=device), 1) 412 413 @float_double_default_dtype 414 def test_alpha_mismatch(self, device): 415 x = torch.ones(4, dtype=torch.int, device=device) 416 err = 'alpha must not be' 417 self.assertRaisesRegex(RuntimeError, err, 418 lambda: torch.add(x, x, alpha=1.1)) 419 x = x.to(torch.bool) 420 self.assertRaisesRegex(RuntimeError, err, 421 lambda: torch.add(x, x, alpha=1.1)) 422 self.assertEqual(x + x, torch.add(x, x, alpha=True)) 423 424 @float_double_default_dtype 425 def test_booleans(self, device): 426 onedim = torch.tensor([True], device=device) 427 428 self.assertEqual(onedim + onedim, onedim) 429 self.assertEqual(onedim + True, onedim) 430 self.assertEqual(torch.add(True, True), True) 431 self.assertEqual(torch.add(False, False), False) 432 self.assertEqual(torch.add(False, True), True) 433 434 self.assertRaisesRegex(RuntimeError, "Boolean alpha only supported", 435 lambda: torch.add(1, 1, alpha=True)) 436 self.assertEqual(torch.add(torch.tensor(True, device=device), 437 torch.tensor(True, device=device), True), 438 torch.tensor(True, device=device)) 439 440 @skipIfTorchDynamo("Not a TorchDynamo suitable test") 441 @float_double_default_dtype 442 def test_create_bool_tensors(self, device): 443 expected = torch.tensor([0], dtype=torch.int64, device=device) 444 self.assertEqual(torch.arange(False, True, device=device), expected) 445 self.assertEqual(torch.arange(True, device=device), expected) 446 expected = torch.tensor([0, 0.5], dtype=torch.get_default_dtype(), device=device) 447 self.assertEqual(torch.arange(False, True, 0.5, device=device), expected) 448 expected = torch.ones(0, dtype=torch.int64, device=device) 449 self.assertEqual(torch.arange(False, False, device=device), expected) 450 451 bool_tensor_lin = torch.linspace(False, True, steps=100, device=device) 452 int_tensor_lin = torch.linspace(0, 1, steps=100, device=device) 453 self.assertEqual(bool_tensor_lin, int_tensor_lin) 454 bool_tensor_log = torch.linspace(False, True, steps=100, device=device) 455 int_tensor_log = torch.linspace(0, 1, steps=100, device=device) 456 self.assertEqual(bool_tensor_log, int_tensor_log) 457 458 # this seems like odd behavior but ints also create float tensors, numpy doesn't have this function. 459 self.assertEqual(torch.scalar_tensor(False, device=device), torch.tensor(0., device=device)) 460 461 @dtypes(*itertools.product(all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), 462 all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))) 463 def test_result_type(self, device, dtypes): 464 "Test result_type for tensor vs tensor and scalar vs scalar." 465 466 def _get_dtype(x): 467 "Get the dtype of x if x is a tensor. If x is a scalar, get its corresponding dtype if it were a tensor." 468 if torch.is_tensor(x): 469 return x.dtype 470 elif isinstance(x, bool): 471 return torch.bool 472 elif isinstance(x, int): 473 return torch.int64 474 elif isinstance(x, float): 475 return torch.float32 476 elif isinstance(x, complex): 477 return torch.complex64 478 else: 479 raise AssertionError(f"Unknown type {x}") 480 481 # tensor against tensor 482 a_tensor = torch.tensor((0, 1), device=device, dtype=dtypes[0]) 483 a_single_tensor = torch.tensor(1, device=device, dtype=dtypes[0]) 484 a_scalar = a_single_tensor.item() 485 b_tensor = torch.tensor((1, 0), device=device, dtype=dtypes[1]) 486 b_single_tensor = torch.tensor(1, device=device, dtype=dtypes[1]) 487 b_scalar = b_single_tensor.item() 488 combo = ((a_tensor, a_single_tensor, a_scalar), (b_tensor, b_single_tensor, b_scalar)) 489 for a, b in itertools.product(*combo): 490 dtype_a = _get_dtype(a) 491 dtype_b = _get_dtype(b) 492 try: 493 result = a + b 494 except RuntimeError: 495 with self.assertRaises(RuntimeError): 496 torch.promote_types(dtype_a, dtype_b) 497 with self.assertRaises(RuntimeError): 498 torch.result_type(a, b) 499 else: 500 dtype_res = _get_dtype(result) 501 if a is a_scalar and b is b_scalar and dtype_a == torch.bool and dtype_b == torch.bool: 502 # special case: in Python, True + True is an integer 503 self.assertEqual(dtype_res, torch.int64, f"a == {a}, b == {b}") 504 else: 505 self.assertEqual(dtype_res, torch.result_type(a, b), f"a == {a}, b == {b}") 506 if a is a_scalar and b is b_scalar: # Python internal type determination is good enough in this case 507 continue 508 if any(a is a0 and b is b0 for a0, b0 in zip(*combo)): # a and b belong to the same class 509 self.assertEqual(dtype_res, torch.promote_types(dtype_a, dtype_b), f"a == {a}, b == {b}") 510 511 # Spot check some result type for tensor against scalar (including single-element tensor). 512 @float_double_default_dtype 513 def test_result_type_tensor_vs_scalar(self, device): 514 def _test_spot(a, b, res_dtype): 515 self.assertEqual(torch.result_type(a, b), res_dtype) 516 self.assertEqual(torch.result_type(b, a), res_dtype) 517 518 _test_spot(torch.tensor([1, 2], dtype=torch.half, device=device), 519 torch.tensor(1, dtype=torch.long, device=device), torch.half) 520 _test_spot(torch.tensor(1, dtype=torch.float, device=device), 521 torch.tensor([1, 2], dtype=torch.double, device=device), torch.double) 522 _test_spot(torch.tensor(1, dtype=torch.int, device=device), 1, torch.int) 523 _test_spot(torch.tensor(1, device=device), 1., torch.get_default_dtype()) 524 _test_spot(torch.tensor(1, dtype=torch.long, device=device), 525 torch.tensor([1, 1], dtype=torch.int, device=device), torch.int) 526 _test_spot(torch.tensor([1., 1.], dtype=torch.float, device=device), 1., torch.float) 527 _test_spot(torch.tensor([1., 1.], dtype=torch.complex64, device=device), 528 torch.tensor(1., dtype=torch.complex128, device=device), torch.complex64) 529 _test_spot(torch.tensor([1., 1.], dtype=torch.complex128, device=device), 530 torch.tensor(1., dtype=torch.complex64, device=device), torch.complex128) 531 _test_spot(torch.tensor([1, 1], dtype=torch.bool, device=device), 1., torch.get_default_dtype()) 532 533 @float_double_default_dtype 534 def test_can_cast(self, device): 535 self.assertTrue(torch.can_cast(torch.double, torch.float)) 536 self.assertFalse(torch.can_cast(torch.float, torch.int)) 537 538 @float_double_default_dtype 539 def test_comparison_ops_with_type_promotion(self, device): 540 value_for_type = { 541 torch.uint8: (1 << 5), 542 torch.int8: (1 << 5), 543 torch.int16: (1 << 10), 544 torch.int32: (1 << 20), 545 torch.int64: (1 << 35), 546 torch.float16: (1 << 10), 547 torch.float32: (1 << 20), 548 torch.float64: (1 << 35), 549 torch.complex64: (1 << 20), 550 torch.complex128: (1 << 35) 551 } 552 comparison_ops = [ 553 dict( 554 name="lt", 555 out_op=lambda x, y, d: torch.lt(x, y, out=torch.empty(0, dtype=torch.bool, device=d)), 556 ret_op=lambda x, y: torch.lt(x, y), 557 compare_op=operator.lt, 558 ), 559 dict( 560 name="le", 561 out_op=lambda x, y, d: torch.le(x, y, out=torch.empty(0, dtype=torch.bool, device=d)), 562 ret_op=lambda x, y: torch.le(x, y), 563 compare_op=operator.le, 564 ), 565 dict( 566 name="gt", 567 out_op=lambda x, y, d: torch.gt(x, y, out=torch.empty(0, dtype=torch.bool, device=d)), 568 ret_op=lambda x, y: torch.gt(x, y), 569 compare_op=operator.gt, 570 ), 571 dict( 572 name="ge", 573 out_op=lambda x, y, d: torch.ge(x, y, out=torch.empty(0, dtype=torch.bool, device=d)), 574 ret_op=lambda x, y: torch.ge(x, y), 575 compare_op=operator.ge, 576 ), 577 dict( 578 name="eq", 579 out_op=lambda x, y, d: torch.eq(x, y, out=torch.empty(0, dtype=torch.bool, device=d)), 580 ret_op=lambda x, y: torch.eq(x, y), 581 compare_op=operator.eq, 582 ), 583 dict( 584 name="ne", 585 out_op=lambda x, y, d: torch.ne(x, y, out=torch.empty(0, dtype=torch.bool, device=d)), 586 ret_op=lambda x, y: torch.ne(x, y), 587 compare_op=operator.ne, 588 ), 589 ] 590 for op in comparison_ops: 591 for dt1 in get_all_math_dtypes(device): 592 for dt2 in get_all_math_dtypes(device): 593 if (dt1.is_complex or dt2.is_complex) and not (op["name"] == "eq" or op["name"] == "ne"): 594 continue 595 val1 = value_for_type[dt1] 596 val2 = value_for_type[dt2] 597 t1 = torch.tensor([val1], dtype=dt1, device=device) 598 t2 = torch.tensor([val2], dtype=dt2, device=device) 599 expected = torch.tensor([op["compare_op"](val1, val2)], dtype=torch.bool) 600 601 out_res = op["out_op"](t1, t2, device) 602 self.assertEqual(out_res, expected) 603 self.assertTrue(out_res.dtype == torch.bool) 604 self.assertTrue(t1.dtype == dt1) 605 self.assertTrue(t2.dtype == dt2) 606 607 out_res = op["ret_op"](t1, t2) 608 self.assertEqual(out_res, expected) 609 self.assertTrue(out_res.dtype == torch.bool) 610 self.assertTrue(t1.dtype == dt1) 611 self.assertTrue(t2.dtype == dt2) 612 613 # test that comparing a zero dim tensor with another zero dim tensor has type promotion behavior 614 t1 = torch.tensor(val1, dtype=dt1, device=device) 615 t2 = torch.tensor(val2, dtype=dt2, device=device) 616 expected = torch.tensor(op["compare_op"](val1, val2), dtype=torch.bool) 617 618 out_res = op["out_op"](t1, t2, device) 619 self.assertEqual(out_res, expected) 620 self.assertTrue(out_res.dtype == torch.bool) 621 self.assertTrue(t1.dtype == dt1) 622 self.assertTrue(t2.dtype == dt2) 623 624 out_res = op["ret_op"](t1, t2) 625 self.assertEqual(out_res, expected) 626 self.assertTrue(out_res.dtype == torch.bool) 627 self.assertTrue(t1.dtype == dt1) 628 self.assertTrue(t2.dtype == dt2) 629 630 # XLA tests fail for self.assertRaises for complex dtypes 631 @onlyNativeDeviceTypes 632 def test_complex_assertraises(self, device): 633 comparison_ops = [ 634 dict(name="lt", compare_op=operator.lt, ), 635 dict(name="le", compare_op=operator.le, ), 636 dict(name="gt", compare_op=operator.gt, ), 637 dict(name="ge", compare_op=operator.ge, ), 638 dict(name="eq", compare_op=operator.eq, ), 639 dict(name="ne", compare_op=operator.ne, ), 640 ] 641 for op in comparison_ops: 642 is_cuda = torch.device(device).type == 'cuda' 643 dtypes = get_all_dtypes(include_half=is_cuda, 644 include_bfloat16=False, include_bool=False, 645 include_complex32=True) 646 647 for dt1, dt2 in itertools.product(dtypes, dtypes): 648 if (dt1.is_complex or dt2.is_complex) and not (op["name"] == "eq" or op["name"] == "ne"): 649 u = torch.tensor([1], dtype=dt1, device=device) 650 v = torch.tensor([2], dtype=dt2, device=device) 651 self.assertRaises(RuntimeError, lambda: torch.tensor([op["compare_op"](u, v)], dtype=torch.bool)) 652 653 @float_double_default_dtype 654 def test_lt_with_type_promotion(self, device): 655 for dt in get_all_math_dtypes(device): 656 x = torch.tensor([0], dtype=dt, device=device) 657 expected = torch.tensor([True], dtype=torch.bool, device=device) 658 659 if dt.is_complex: 660 continue 661 662 actual = x < 0.5 663 self.assertTrue(actual, expected) 664 self.assertTrue(actual.dtype == torch.bool) 665 666 actual = x < torch.tensor(0.5, device=device) 667 self.assertTrue(actual, expected) 668 self.assertTrue(actual.dtype == torch.bool) 669 670 x = torch.tensor(0, dtype=dt, device=device) 671 expected = torch.tensor(True, dtype=torch.bool, device=device) 672 actual = x < 0.5 673 self.assertTrue(actual, expected) 674 self.assertTrue(actual.dtype == torch.bool) 675 676 actual = x < torch.tensor(0.5, device=device) 677 self.assertTrue(actual, expected) 678 self.assertTrue(actual.dtype == torch.bool) 679 680 @float_double_default_dtype 681 def test_promote_types(self, device): 682 self.assertEqual(torch.promote_types(torch.float, torch.int), torch.float) 683 self.assertEqual(torch.promote_types(torch.float, torch.double), torch.double) 684 self.assertEqual(torch.promote_types(torch.int, torch.uint8), torch.int) 685 with self.assertRaisesRegex(RuntimeError, "Promotion for Float8 Types is not supported"): 686 self.assertEqual(torch.promote_types(torch.float8_e5m2, torch.float), torch.float) 687 with self.assertRaisesRegex(RuntimeError, "Promotion for Float8 Types is not supported"): 688 self.assertEqual(torch.promote_types(torch.float, torch.float8_e4m3fn), torch.float) 689 690 @float_double_default_dtype 691 def test_promote_self(self, device): 692 for dtype in all_types_and_complex_and(torch.half, torch.bfloat16, torch.chalf, torch.bool, 693 torch.float8_e5m2, torch.float8_e4m3fn): 694 self.assertEqual(torch.promote_types(dtype, dtype), dtype) 695 696 @expectedFailureMeta 697 @float_double_default_dtype 698 def test_indexing_fail(self, device): 699 # https://github.com/pytorch/pytorch/issues/28010 700 a = torch.ones(5, 2, dtype=torch.double, device=device) 701 b = torch.zeros(5, dtype=torch.int, device=device) 702 with self.assertRaises(RuntimeError): 703 a[:, [1]] = b.unsqueeze(-1) 704 705 @float_double_default_dtype 706 def test_indexing(self, device): 707 x = torch.ones(5, 2, dtype=torch.double, device=device) 708 y = torch.zeros(5, dtype=torch.double, device=device) 709 x[:, [1]] = y.unsqueeze(-1) 710 expected = torch.tensor([(1, 0), (1, 0), (1, 0), (1, 0), (1, 0)], dtype=torch.double, device=device) 711 self.assertEqual(x, expected) 712 713 714 # https://github.com/pytorch/pytorch/issues/27824 715 tmp = torch.ones(9, 9, dtype=torch.float, device=device) 716 mask = torch.ones(10, 10, dtype=torch.uint8, device=device) 717 result = tmp + mask[1:, 1:] 718 expected = torch.full([9, 9], 2., dtype=torch.float, device=device).fill_(2.) 719 self.assertEqual(result, expected) 720 721 @float_double_default_dtype 722 def test_transpose(self, device): 723 # https://github.com/pytorch/pytorch/issues/28502 724 a = torch.tensor([[True, True], [False, True]], device=device) 725 self.assertEqual(a.t() == 0, a.t() == False) # noqa: E712 726 727 @dtypes(torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64) 728 @float_double_default_dtype 729 def test_div_promotion(self, device, dtype): 730 for op in (torch.div, torch.true_divide): 731 dividend = (torch.randn(5, device=device) * 100).to(dtype) 732 divisor = torch.arange(1, 6, device=device).to(dtype) 733 734 # Tests tensor/tensor division 735 casting_result = dividend.to(torch.get_default_dtype()) / divisor.to(torch.get_default_dtype()) 736 self.assertEqual(casting_result, op(dividend, divisor)) 737 738 # Tests tensor/scalar division 739 casting_result = dividend.to(torch.get_default_dtype()) / 2 740 self.assertEqual(casting_result, op(dividend, 2.)) 741 742 @onlyNativeDeviceTypes 743 @dtypes(torch.float, torch.double, 744 torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64) 745 def test_div_promotion_out(self, device, dtype): 746 for op in (torch.div, torch.true_divide): 747 dividend = (torch.randn(5, device=device) * 100).to(dtype) 748 divisor = torch.arange(1, 6, device=device).to(dtype) 749 750 # Tests that requests for an integer quotient fail 751 if not dtype.is_floating_point: 752 integral_quotient = torch.empty(5, device=device, dtype=dtype) 753 with self.assertRaises(RuntimeError): 754 op(dividend, divisor, out=integral_quotient) 755 with self.assertRaises(RuntimeError): 756 op(dividend, 2, out=integral_quotient) 757 else: 758 # Tests that requests for a floating quotient succeed 759 floating_quotient = torch.empty(5, device=device, dtype=dtype) 760 div_result = dividend / divisor 761 self.assertEqual(div_result, 762 op(dividend, divisor, out=floating_quotient)) 763 self.assertEqual(dividend / 2, 764 op(dividend, 2, out=floating_quotient)) 765 766 @dtypes(torch.float, torch.double, 767 torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64) 768 def test_div_promotion_inplace(self, device, dtype): 769 for op in (torch.Tensor.div_, torch.Tensor.true_divide_): 770 dividend = (torch.randn(5, device=device) * 100).to(dtype) 771 divisor = torch.arange(1, 6, device=device).to(dtype) 772 773 # Tests that requests for an integer quotient fail 774 if not dtype.is_floating_point: 775 with self.assertRaises(RuntimeError): 776 op(dividend, divisor) 777 with self.assertRaises(RuntimeError): 778 op(dividend, 2) 779 else: 780 # Tests that requests for a floating quotient succeed 781 div_result = dividend.clone().div_(divisor) 782 self.assertEqual(div_result, op(dividend.clone(), divisor)) 783 self.assertEqual(dividend.clone().div_(2), op(dividend.clone(), 2)) 784 785 def _test_sparse_op_input_tensors(self, device, dtype, coalesced, zeros=True): 786 t = self._get_test_tensor(device, dtype, not zeros) 787 if zeros and dtype != torch.bool: 788 # ensure sparsity. Bool should already have sufficient sparsity. 789 mask = self._get_test_tensor(device, torch.bool) 790 t = t * mask 791 792 if coalesced: 793 s = t.to_sparse() 794 else: 795 s = t.to_sparse() 796 indices = torch.cat((s.indices(), s.indices()), 1) 797 values = torch.cat((s.values(), s.values()), 0) 798 s = torch.sparse_coo_tensor(indices=indices, values=values, size=s.size(), dtype=dtype, device=device) 799 t = s.to_dense() 800 self.assertEqual(s.is_coalesced(), coalesced) 801 self.assertEqual(s.dtype, dtype) 802 self.assertEqual(t.dtype, s.dtype) 803 return t, s 804 805 def _get_precision(self, dtype, coalesced): 806 if dtype == torch.half and not coalesced: 807 # very low precision for uncoalesced float16 sparse tensors since 808 # ops like (s1 + s2).to_dense() will add four low-precision 809 # floating point values. 810 return 5e-2 811 if dtype == torch.half: 812 return 1e-3 813 # uses default 814 return None 815 816 def _test_sparse_op(self, op_name, inplace, dtype1, dtype2, device, coalesced): 817 if dtype1.is_complex or dtype2.is_complex: 818 return 819 820 suffix = '_' if inplace else '' 821 err = f"{' coalesced' if coalesced else 'uncoalesced'} {op_name + suffix}({dtype1}, {dtype2})" 822 823 def op(t1, t2, suf=None): 824 suf = suffix if suf is None else suf 825 return getattr(t1, op_name + suf)(t2) 826 827 add_sub = op_name == 'add' or op_name == 'sub' 828 829 (dense1, sparse1) = self._test_sparse_op_input_tensors(device, dtype1, coalesced) 830 (dense2, sparse2) = self._test_sparse_op_input_tensors(device, dtype2, coalesced, op_name != 'div') 831 832 common_dtype = torch.result_type(dense1, dense2) 833 if self.device_type == 'cpu' and common_dtype == torch.half: 834 self.assertRaises(RuntimeError, lambda: op(s1, d2)) 835 836 # Skip inplace tests that would fail due to inability to cast to the output type. 837 # Some of these would also raise errors due to not being a supported op. 838 if inplace and not torch.can_cast(common_dtype, dtype1): 839 self.assertRaises(RuntimeError, lambda: op(dense1, sparse2)) 840 self.assertRaises(RuntimeError, lambda: op(sparse1, sparse2)) 841 self.assertRaises(RuntimeError, lambda: op(sparse1, dense2)) 842 return 843 844 expected = op(dense1.clone(), dense2) 845 precision = self._get_precision(expected.dtype, coalesced) 846 rtol = None if precision is None else 0 847 test_tensors = [expected, dense1, sparse1, dense2, sparse2] 848 e, d1, s1, d2, s2 = [x.clone() for x in test_tensors] if inplace else test_tensors 849 850 # Test op(sparse, sparse) 851 if op_name != 'div': 852 sparse = op(s1, s2) 853 self.assertEqual(sparse.dtype, e.dtype) 854 self.assertEqual(e, sparse.to_dense(), atol=precision, rtol=rtol, msg=err) 855 else: 856 # sparse division only supports division by a scalar 857 self.assertRaises(RuntimeError, lambda: op(s1, s2).to_dense()) 858 859 # Test op(dense, sparse) 860 if add_sub or op_name == 'mul': 861 if inplace: 862 e, d1, s1, d2, s2 = (x.clone() for x in test_tensors) 863 dense_sparse = op(d1, s2) 864 dense_sparse = dense_sparse.to_dense() if dense_sparse.is_sparse else dense_sparse 865 self.assertEqual(e, dense_sparse, atol=precision, rtol=rtol, msg=err) 866 else: 867 # sparse division only supports division by a scalar 868 # mul: Didn't find kernel to dispatch to for operator 'aten::_nnz' 869 self.assertRaises(RuntimeError, lambda: op(d1, s2)) 870 871 # Test op(sparse, dense) not supported for all ops but 'mul'. 872 # add(sparse, dense) is not supported. Use add(dense, sparse) instead. 873 # sparse division only supports division by a scalar 874 if op_name != 'mul': 875 self.assertRaises(RuntimeError, lambda: op(s1, d2)) 876 else: 877 # No type promotions for inplace operations, hence suf='' 878 op(s1, d2, suf='') 879 880 # Test op(sparse, scalar) 881 if not add_sub and not (self.device_type == 'cpu' and dtype1 == torch.half): 882 if inplace: 883 e, d1, s1, d2, s2 = (x.clone() for x in test_tensors) 884 scalar = d2.view(d2.numel())[0].item() 885 886 sparse = op(s1, scalar) 887 dense_scalar = op(d1, scalar) 888 self.assertEqual(sparse.dtype, dense_scalar.dtype) 889 self.assertEqual(dense_scalar, sparse.to_dense(), atol=precision, rtol=rtol, msg=err) 890 else: 891 # add(sparse, dense) is not supported. Use add(dense, sparse) instead. 892 # "mul_cpu" / "div_cpu" not implemented for 'Half' 893 self.assertRaises(RuntimeError, lambda: op(s1, d2.view(d2.numel())[0].item())) 894 895 def _run_all_tests_for_sparse_op(self, op_name, device, dtypes): 896 for dtype1, dtype2 in itertools.product(dtypes, dtypes): 897 for inplace, coalesced in itertools.product([True, False], [True, False]): 898 self._test_sparse_op(op_name, inplace, dtype1, dtype2, device, coalesced) 899 900 @onlyNativeDeviceTypes 901 def test_sparse_add(self, device): 902 self._run_all_tests_for_sparse_op('add', device, 903 dtypes=get_all_math_dtypes(device)) 904 905 @onlyNativeDeviceTypes 906 def test_sparse_mul(self, device): 907 self._run_all_tests_for_sparse_op('mul', device, 908 dtypes=get_all_math_dtypes(device)) 909 910 @onlyNativeDeviceTypes 911 def test_sparse_div(self, device): 912 self._run_all_tests_for_sparse_op('div', device, 913 dtypes=(torch.float32, torch.float64, 914 torch.complex64, torch.complex128)) 915 916 @onlyNativeDeviceTypes 917 def test_sparse_sub(self, device): 918 self._run_all_tests_for_sparse_op('sub', device, 919 dtypes=get_all_math_dtypes(device)) 920 921 @onlyNativeDeviceTypes 922 @dtypes(torch.bool, torch.short, torch.uint8, torch.int, torch.long) 923 @float_double_default_dtype 924 def test_sparse_div_promotion(self, device, dtype): 925 for op in (torch.div, torch.true_divide): 926 dividend = torch.randn(5, device=device).to(dtype) 927 divisor = 2 928 dividend_sparse = dividend.to_sparse() 929 casting_result = dividend.to(torch.get_default_dtype()) / 2 930 self.assertEqual(casting_result, op(dividend_sparse, 2).to_dense()) 931 932 @onlyNativeDeviceTypes 933 @dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64) 934 def test_integer_addcdiv_deprecated(self, device, dtype): 935 t = torch.tensor(1, device=device, dtype=dtype) 936 937 with self.assertRaisesRegex(RuntimeError, '^Integer division.+is no longer supported.+'): 938 torch.addcdiv(t, t, t) 939 with self.assertRaisesRegex(RuntimeError, '^Integer division.+is no longer supported.+'): 940 torch.addcdiv(t, t, t, out=t) 941 with self.assertRaisesRegex(RuntimeError, '^Integer division.+is no longer supported+'): 942 t.addcdiv_(t, t) 943 944 @unittest.skipIf(not TEST_NUMPY, "NumPy not found") 945 @float_double_default_dtype 946 @onlyCPU 947 # NB: skip uint16,32,64 as PyTorch doesn't implement promotion for them 948 @dtypes(*list(itertools.product( 949 set(numpy_to_torch_dtype_dict.values()) - {torch.uint16, torch.uint32, torch.uint64}, 950 set(numpy_to_torch_dtype_dict.values()) - {torch.uint16, torch.uint32, torch.uint64}))) 951 def test_numpy_array_binary_ufunc_promotion(self, device, dtypes): 952 import operator 953 np_type = torch_to_numpy_dtype_dict[dtypes[0]] 954 torch_type = dtypes[1] 955 956 t = torch.tensor((1,), device=device, dtype=torch_type) 957 a = np.array((1,), dtype=np_type) 958 a_as_t = torch.from_numpy(a).to(device=device) 959 960 for np_first in (True, False): 961 for op in (operator.add, torch.add): 962 963 # Acquires results of binary ufunc type promotion. 964 try: 965 actual = op(a, t) if np_first else op(t, a) 966 except Exception as e: 967 actual = e 968 969 try: 970 expected = op(a_as_t, t) if np_first else op(t, a_as_t) 971 except Exception as e: 972 expected = e 973 974 same_result = (type(expected) == type(actual)) and expected == actual 975 976 # Note: An "undesired failure," as opposed to an "expected failure" 977 # is both expected (we know the test will fail) and 978 # undesirable (if PyTorch was working properly the test would 979 # not fail). This test is affected by three issues (see below) 980 # that will cause undesired failures. It detects when these 981 # issues will occur and updates this bool accordingly. 982 undesired_failure = False 983 984 # A NumPy array as the first argument to the plus operator 985 # or as any argument to torch.add is not working as 986 # intended. 987 # See https://github.com/pytorch/pytorch/issues/36363. 988 if np_first and op is operator.add: 989 undesired_failure = True 990 if op is torch.add: 991 undesired_failure = True 992 993 # Expects the same result if undesired_failure is false 994 # and a different result otherwise. 995 # Note: These cases prettyprint the failing inputs to make 996 # debugging test failures easier. 997 if undesired_failure and same_result: 998 msg = ( 999 f"Failure: {actual} == {expected}. torch type was {torch_type}. " 1000 f"NumPy type was {np_type}. np_first is {np_first} default type is " 1001 f"{torch.get_default_dtype()}." 1002 ) 1003 self.fail(msg) 1004 1005 if not undesired_failure and not same_result: 1006 msg = ( 1007 f"Failure: {actual} != {expected}. torch type was {torch_type}. " 1008 f"NumPy type was {np_type}. np_first is {np_first} default type is " 1009 f"{torch.get_default_dtype()}." 1010 ) 1011 self.fail(msg) 1012 1013 1014 @onlyNativeDeviceTypes 1015 def test_cat_different_dtypes(self, device): 1016 dtypes = all_types_and_complex_and(torch.half, torch.bool) 1017 for x_dtype, y_dtype in itertools.product(dtypes, dtypes): 1018 x_vals, y_vals = [1, 2, 3], [4, 5, 6] 1019 1020 x = torch.tensor(x_vals, device=device, dtype=x_dtype) 1021 y = torch.tensor(y_vals, device=device, dtype=y_dtype) 1022 1023 if x_dtype is torch.bool: 1024 x_vals = [1, 1, 1] 1025 if y_dtype is torch.bool: 1026 y_vals = [1, 1, 1] 1027 1028 res_dtype = torch.result_type(x, y) 1029 expected_res = torch.tensor(x_vals + y_vals, device=device, dtype=res_dtype) 1030 res = torch.cat([x, y]) 1031 self.assertEqual(res, expected_res, exact_dtype=True) 1032 1033 # cat: full and an empty tensor. 1034 y = torch.tensor([], device=device, dtype=y_dtype) 1035 res_dtype = torch.result_type(x, y) 1036 expected_res = torch.tensor(x_vals + [], device=device, dtype=res_dtype) 1037 res = torch.cat([x, y]) 1038 self.assertEqual(res, expected_res, exact_dtype=True) 1039 1040 @onlyNativeDeviceTypes 1041 def test_cat_out_different_dtypes(self, device): 1042 dtypes = all_types_and_complex_and(torch.half) 1043 for x_dtype, y_dtype, out_dtype in itertools.product(dtypes, dtypes, dtypes): 1044 out = torch.zeros(6, device=device, dtype=out_dtype) 1045 x = torch.tensor([1, 2, 3], device=device, dtype=x_dtype) 1046 y = torch.tensor([4, 5, 6], device=device, dtype=y_dtype) 1047 expected_out = torch.tensor([1, 2, 3, 4, 5, 6], device=device, dtype=out_dtype) 1048 if (((x_dtype.is_floating_point or y_dtype.is_floating_point) 1049 and not (out_dtype.is_floating_point or out_dtype.is_complex)) 1050 or ((x_dtype.is_complex or y_dtype.is_complex) and not out_dtype.is_complex)): 1051 # This combinations do not support type conversion to a different class out type 1052 with self.assertRaises(RuntimeError): 1053 torch.cat([x, y], out=out) 1054 else: 1055 torch.cat([x, y], out=out) 1056 self.assertEqual(out, expected_out, exact_dtype=True) 1057 1058 # Verfies that unary ops require matching out types 1059 @onlyNativeDeviceTypes 1060 @dtypes(*itertools.product((torch.int64, 1061 torch.float32, torch.float64, 1062 torch.complex64, torch.complex128), 1063 (torch.int64, 1064 torch.float32, torch.float64, 1065 torch.complex64, torch.complex128))) 1066 def test_unary_op_out_casting(self, device, dtypes): 1067 t = torch.tensor((1), dtype=dtypes[0], device=device) 1068 out = torch.empty(0, dtype=dtypes[1], device=device) 1069 1070 ops = (torch.neg, torch.floor, torch.ceil) 1071 float_and_int_only_ops = {torch.floor, torch.ceil} 1072 real_only_ops = {torch.floor, torch.ceil} 1073 for op in ops: 1074 if dtypes[0] is not dtypes[1]: 1075 with self.assertRaises(RuntimeError): 1076 op(t, out=out) 1077 elif op in real_only_ops and dtypes[0].is_complex: 1078 with self.assertRaises(RuntimeError): 1079 op(t, out=out) 1080 elif ( 1081 op in float_and_int_only_ops 1082 and (not dtypes[0].is_floating_point and not dtypes[0].is_complex) 1083 and (not (dtypes[0] == torch.int64 and dtypes[1] == torch.int64)) 1084 and device != "meta" 1085 ): 1086 with self.assertRaises(RuntimeError): 1087 op(t, out=out) 1088 else: 1089 self.assertEqual(op(t, out=out), op(t)) 1090 self.assertEqual(op(t, out=out), out) 1091 1092 # Verifies that the out= argument doesn't affect the computation, that 1093 # is, out = op(...) and op(..., out=out) produce the same result. 1094 @onlyNativeDeviceTypes 1095 @skipMeta 1096 def test_computation_ignores_out(self, device): 1097 t = torch.tensor(33000, dtype=torch.float16, device=device) 1098 out = torch.empty(0, dtype=torch.float64, device=device) 1099 result = torch.add(t, t, out=out) 1100 self.assertEqual(result, t + t, exact_dtype=False) 1101 self.assertNotEqual(result, t.double() + t, exact_dtype=False) 1102 1103 a = torch.tensor(1.5, dtype=torch.float16, device=device) 1104 b = torch.tensor(.666, dtype=torch.float16, device=device) 1105 result = torch.true_divide(a, b, out=out) 1106 self.assertEqual(result, a / b, exact_dtype=False) 1107 self.assertNotEqual(result, a.double() / a, exact_dtype=False) 1108 1109 a = torch.tensor(5, dtype=torch.uint8, device=device) 1110 b = torch.tensor(8, dtype=torch.uint8, device=device) 1111 result = torch.sub(a, b, out=out) 1112 self.assertEqual(result, a - b, exact_dtype=False) 1113 self.assertNotEqual(result, a.double() - b, exact_dtype=False) 1114 1115 @onlyNativeDeviceTypes 1116 @dtypes(*itertools.product((torch.bool, torch.int, torch.float, torch.double), repeat=3)) 1117 def test_clamp_type_promotion(self, device, dtypes): 1118 dtype0, dtype1, dtype2 = dtypes 1119 S = 4 1120 1121 def make_tensor(size, dtype): 1122 if dtype == torch.bool: 1123 return torch.randint(2, size, dtype=dtype, device=device) 1124 elif dtype == torch.int: 1125 return torch.randint(10, size, dtype=dtype, device=device) 1126 else: 1127 return torch.randn(size, dtype=dtype, device=device) 1128 min_t = make_tensor((S,), dtype1) 1129 max_t = make_tensor((S,), dtype2) 1130 mins = (min_t, min_t[0], min_t[0].item()) 1131 maxs = (max_t, max_t[0], max_t[0].item()) 1132 inp = make_tensor((S,), dtype0) 1133 for min_v, max_v in itertools.product(mins, maxs): 1134 if type(max_v) != type(min_v): 1135 continue 1136 if isinstance(min_v, torch.Tensor) and min_v.ndim == 0 and max_v.ndim == 0: 1137 continue # 0d tensors go to scalar overload, and it's tested separately 1138 1139 def expected_type(inp, max, min): 1140 arg1, arg2 = max, min 1141 if isinstance(max, torch.Tensor) and max.ndim == 0: 1142 # first do a maybe dimensional boundary 1143 arg1, arg2 = min, max 1144 exp_type = torch.result_type(inp, arg1) 1145 inp_new = torch.empty_like(inp, dtype=exp_type) 1146 return torch.result_type(inp_new, arg2) 1147 exp_type = expected_type(inp, min_v, max_v) 1148 if exp_type != torch.bool: 1149 actual = torch.clamp(inp, min_v, max_v) 1150 inps = [x.to(exp_type) if isinstance(x, torch.Tensor) else x for x in (inp, min_v, max_v)] 1151 expected = torch.clamp(inps[0], inps[1], inps[2]) 1152 self.assertEqual(actual, expected) 1153 if inp.dtype in floating_types() or exp_type == inp.dtype: 1154 actual = torch.clamp_(inp, min_v, max_v) 1155 self.assertEqual(actual, expected, exact_dtype=False) 1156 for val in mins: 1157 def expected_type(inp, val): 1158 return torch.result_type(inp, val) 1159 exp_type = expected_type(inp, val) 1160 if exp_type != torch.bool: 1161 actual = torch.clamp_min(inp, val) 1162 inps = [x.to(exp_type) if isinstance(x, torch.Tensor) else x for x in (inp, val)] 1163 expected = torch.clamp_min(inps[0], inps[1]) 1164 self.assertEqual(actual.dtype, exp_type) 1165 self.assertEqual(actual, expected) 1166 if inp.dtype == exp_type: 1167 actual = torch.clamp_min_(inp, val) 1168 self.assertEqual(actual, expected) 1169 actual = torch.clamp_max(inp, val) 1170 expected = torch.clamp_max(inps[0], inps[1]) 1171 self.assertEqual(actual, expected) 1172 if inp.dtype in floating_types() or exp_type == inp.dtype: 1173 actual = torch.clamp_max_(inp, val) 1174 self.assertEqual(actual, expected, exact_dtype=False) 1175 1176 @onlyNativeDeviceTypes 1177 def test_ternary_out_promotion(self, device): 1178 for op in [torch.addcdiv, torch.addcmul]: 1179 for dtype in [torch.float32, torch.cfloat]: 1180 prom_dtype = torch.float64 if dtype is torch.float32 else torch.cdouble if dtype is torch.cfloat else dtype 1181 x = torch.rand(3, device=device, dtype=dtype) 1182 y = torch.empty(3, device=device, dtype=dtype) 1183 y_promo = torch.empty(3, device=device, dtype=prom_dtype) 1184 op(x, x, x, out=y) 1185 op(x, x, x, out=y_promo) 1186 self.assertEqual(y, y_promo.to(dtype=dtype)) 1187 1188 1189 1190 1191instantiate_device_type_tests(TestTypePromotion, globals()) 1192 1193if __name__ == '__main__': 1194 run_tests() 1195