# mypy: allow-untyped-defs # Owner(s): ["module: complex"] import torch from torch.testing._internal.common_device_type import ( dtypes, instantiate_device_type_tests, onlyCPU, ) from torch.testing._internal.common_dtype import complex_types from torch.testing._internal.common_utils import run_tests, set_default_dtype, TestCase devices = (torch.device("cpu"), torch.device("cuda:0")) class TestComplexTensor(TestCase): @dtypes(*complex_types()) def test_to_list(self, device, dtype): # test that the complex float tensor has expected values and # there's no garbage value in the resultant list self.assertEqual( torch.zeros((2, 2), device=device, dtype=dtype).tolist(), [[0j, 0j], [0j, 0j]], ) @dtypes(torch.float32, torch.float64, torch.float16) def test_dtype_inference(self, device, dtype): # issue: https://github.com/pytorch/pytorch/issues/36834 with set_default_dtype(dtype): x = torch.tensor([3.0, 3.0 + 5.0j], device=device) if dtype == torch.float16: self.assertEqual(x.dtype, torch.chalf) elif dtype == torch.float32: self.assertEqual(x.dtype, torch.cfloat) else: self.assertEqual(x.dtype, torch.cdouble) @dtypes(*complex_types()) def test_conj_copy(self, device, dtype): # issue: https://github.com/pytorch/pytorch/issues/106051 x1 = torch.tensor([5 + 1j, 2 + 2j], device=device, dtype=dtype) xc1 = torch.conj(x1) x1.copy_(xc1) self.assertEqual(x1, torch.tensor([5 - 1j, 2 - 2j], device=device, dtype=dtype)) @dtypes(*complex_types()) def test_all(self, device, dtype): # issue: https://github.com/pytorch/pytorch/issues/120875 x = torch.tensor([1 + 2j, 3 - 4j, 5j, 6], device=device, dtype=dtype) self.assertTrue(torch.all(x)) @dtypes(*complex_types()) def test_any(self, device, dtype): # issue: https://github.com/pytorch/pytorch/issues/120875 x = torch.tensor( [0, 0j, -0 + 0j, -0 - 0j, 0 + 0j, 0 - 0j], device=device, dtype=dtype ) self.assertFalse(torch.any(x)) @onlyCPU @dtypes(*complex_types()) def test_eq(self, device, dtype): "Test eq on complex types" nan = float("nan") # Non-vectorized operations for a, b in ( ( torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype), torch.tensor([-6.1278 - 8.5019j], device=device, dtype=dtype), ), ( torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype), torch.tensor([-6.1278 - 2.1172j], device=device, dtype=dtype), ), ( torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype), torch.tensor([-0.0610 - 8.5019j], device=device, dtype=dtype), ), ): actual = torch.eq(a, b) expected = torch.tensor([False], device=device, dtype=torch.bool) self.assertEqual( actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}" ) actual = torch.eq(a, a) expected = torch.tensor([True], device=device, dtype=torch.bool) self.assertEqual( actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}" ) actual = torch.full_like(b, complex(2, 2)) torch.eq(a, b, out=actual) expected = torch.tensor([complex(0)], device=device, dtype=dtype) self.assertEqual( actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}" ) actual = torch.full_like(b, complex(2, 2)) torch.eq(a, a, out=actual) expected = torch.tensor([complex(1)], device=device, dtype=dtype) self.assertEqual( actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}" ) # Vectorized operations for a, b in ( ( torch.tensor( [ -0.0610 - 2.1172j, 5.1576 + 5.4775j, complex(2.8871, nan), -6.6545 - 3.7655j, -2.7036 - 1.4470j, 0.3712 + 7.989j, -0.0610 - 2.1172j, 5.1576 + 5.4775j, complex(nan, -3.2650), -6.6545 - 3.7655j, -2.7036 - 1.4470j, 0.3712 + 7.989j, ], device=device, dtype=dtype, ), torch.tensor( [ -6.1278 - 8.5019j, 0.5886 + 8.8816j, complex(2.8871, nan), 6.3505 + 2.2683j, 0.3712 + 7.9659j, 0.3712 + 7.989j, -6.1278 - 2.1172j, 5.1576 + 8.8816j, complex(nan, -3.2650), 6.3505 + 2.2683j, 0.3712 + 7.9659j, 0.3712 + 7.989j, ], device=device, dtype=dtype, ), ), ): actual = torch.eq(a, b) expected = torch.tensor( [ False, False, False, False, False, True, False, False, False, False, False, True, ], device=device, dtype=torch.bool, ) self.assertEqual( actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}" ) actual = torch.eq(a, a) expected = torch.tensor( [ True, True, False, True, True, True, True, True, False, True, True, True, ], device=device, dtype=torch.bool, ) self.assertEqual( actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}" ) actual = torch.full_like(b, complex(2, 2)) torch.eq(a, b, out=actual) expected = torch.tensor( [ complex(0), complex(0), complex(0), complex(0), complex(0), complex(1), complex(0), complex(0), complex(0), complex(0), complex(0), complex(1), ], device=device, dtype=dtype, ) self.assertEqual( actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}" ) actual = torch.full_like(b, complex(2, 2)) torch.eq(a, a, out=actual) expected = torch.tensor( [ complex(1), complex(1), complex(0), complex(1), complex(1), complex(1), complex(1), complex(1), complex(0), complex(1), complex(1), complex(1), ], device=device, dtype=dtype, ) self.assertEqual( actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}" ) @onlyCPU @dtypes(*complex_types()) def test_ne(self, device, dtype): "Test ne on complex types" nan = float("nan") # Non-vectorized operations for a, b in ( ( torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype), torch.tensor([-6.1278 - 8.5019j], device=device, dtype=dtype), ), ( torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype), torch.tensor([-6.1278 - 2.1172j], device=device, dtype=dtype), ), ( torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype), torch.tensor([-0.0610 - 8.5019j], device=device, dtype=dtype), ), ): actual = torch.ne(a, b) expected = torch.tensor([True], device=device, dtype=torch.bool) self.assertEqual( actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}" ) actual = torch.ne(a, a) expected = torch.tensor([False], device=device, dtype=torch.bool) self.assertEqual( actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}" ) actual = torch.full_like(b, complex(2, 2)) torch.ne(a, b, out=actual) expected = torch.tensor([complex(1)], device=device, dtype=dtype) self.assertEqual( actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}" ) actual = torch.full_like(b, complex(2, 2)) torch.ne(a, a, out=actual) expected = torch.tensor([complex(0)], device=device, dtype=dtype) self.assertEqual( actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}" ) # Vectorized operations for a, b in ( ( torch.tensor( [ -0.0610 - 2.1172j, 5.1576 + 5.4775j, complex(2.8871, nan), -6.6545 - 3.7655j, -2.7036 - 1.4470j, 0.3712 + 7.989j, -0.0610 - 2.1172j, 5.1576 + 5.4775j, complex(nan, -3.2650), -6.6545 - 3.7655j, -2.7036 - 1.4470j, 0.3712 + 7.989j, ], device=device, dtype=dtype, ), torch.tensor( [ -6.1278 - 8.5019j, 0.5886 + 8.8816j, complex(2.8871, nan), 6.3505 + 2.2683j, 0.3712 + 7.9659j, 0.3712 + 7.989j, -6.1278 - 2.1172j, 5.1576 + 8.8816j, complex(nan, -3.2650), 6.3505 + 2.2683j, 0.3712 + 7.9659j, 0.3712 + 7.989j, ], device=device, dtype=dtype, ), ), ): actual = torch.ne(a, b) expected = torch.tensor( [ True, True, True, True, True, False, True, True, True, True, True, False, ], device=device, dtype=torch.bool, ) self.assertEqual( actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}" ) actual = torch.ne(a, a) expected = torch.tensor( [ False, False, True, False, False, False, False, False, True, False, False, False, ], device=device, dtype=torch.bool, ) self.assertEqual( actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}" ) actual = torch.full_like(b, complex(2, 2)) torch.ne(a, b, out=actual) expected = torch.tensor( [ complex(1), complex(1), complex(1), complex(1), complex(1), complex(0), complex(1), complex(1), complex(1), complex(1), complex(1), complex(0), ], device=device, dtype=dtype, ) self.assertEqual( actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}" ) actual = torch.full_like(b, complex(2, 2)) torch.ne(a, a, out=actual) expected = torch.tensor( [ complex(0), complex(0), complex(1), complex(0), complex(0), complex(0), complex(0), complex(0), complex(1), complex(0), complex(0), complex(0), ], device=device, dtype=dtype, ) self.assertEqual( actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}" ) instantiate_device_type_tests(TestComplexTensor, globals()) if __name__ == "__main__": TestCase._default_dtype_check_enabled = True run_tests()