1# mypy: allow-untyped-defs 2# Owner(s): ["module: complex"] 3 4import torch 5from torch.testing._internal.common_device_type import ( 6 dtypes, 7 instantiate_device_type_tests, 8 onlyCPU, 9) 10from torch.testing._internal.common_dtype import complex_types 11from torch.testing._internal.common_utils import run_tests, set_default_dtype, TestCase 12 13 14devices = (torch.device("cpu"), torch.device("cuda:0")) 15 16 17class TestComplexTensor(TestCase): 18 @dtypes(*complex_types()) 19 def test_to_list(self, device, dtype): 20 # test that the complex float tensor has expected values and 21 # there's no garbage value in the resultant list 22 self.assertEqual( 23 torch.zeros((2, 2), device=device, dtype=dtype).tolist(), 24 [[0j, 0j], [0j, 0j]], 25 ) 26 27 @dtypes(torch.float32, torch.float64, torch.float16) 28 def test_dtype_inference(self, device, dtype): 29 # issue: https://github.com/pytorch/pytorch/issues/36834 30 with set_default_dtype(dtype): 31 x = torch.tensor([3.0, 3.0 + 5.0j], device=device) 32 if dtype == torch.float16: 33 self.assertEqual(x.dtype, torch.chalf) 34 elif dtype == torch.float32: 35 self.assertEqual(x.dtype, torch.cfloat) 36 else: 37 self.assertEqual(x.dtype, torch.cdouble) 38 39 @dtypes(*complex_types()) 40 def test_conj_copy(self, device, dtype): 41 # issue: https://github.com/pytorch/pytorch/issues/106051 42 x1 = torch.tensor([5 + 1j, 2 + 2j], device=device, dtype=dtype) 43 xc1 = torch.conj(x1) 44 x1.copy_(xc1) 45 self.assertEqual(x1, torch.tensor([5 - 1j, 2 - 2j], device=device, dtype=dtype)) 46 47 @dtypes(*complex_types()) 48 def test_all(self, device, dtype): 49 # issue: https://github.com/pytorch/pytorch/issues/120875 50 x = torch.tensor([1 + 2j, 3 - 4j, 5j, 6], device=device, dtype=dtype) 51 self.assertTrue(torch.all(x)) 52 53 @dtypes(*complex_types()) 54 def test_any(self, device, dtype): 55 # issue: https://github.com/pytorch/pytorch/issues/120875 56 x = torch.tensor( 57 [0, 0j, -0 + 0j, -0 - 0j, 0 + 0j, 0 - 0j], device=device, dtype=dtype 58 ) 59 self.assertFalse(torch.any(x)) 60 61 @onlyCPU 62 @dtypes(*complex_types()) 63 def test_eq(self, device, dtype): 64 "Test eq on complex types" 65 nan = float("nan") 66 # Non-vectorized operations 67 for a, b in ( 68 ( 69 torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype), 70 torch.tensor([-6.1278 - 8.5019j], device=device, dtype=dtype), 71 ), 72 ( 73 torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype), 74 torch.tensor([-6.1278 - 2.1172j], device=device, dtype=dtype), 75 ), 76 ( 77 torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype), 78 torch.tensor([-0.0610 - 8.5019j], device=device, dtype=dtype), 79 ), 80 ): 81 actual = torch.eq(a, b) 82 expected = torch.tensor([False], device=device, dtype=torch.bool) 83 self.assertEqual( 84 actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}" 85 ) 86 87 actual = torch.eq(a, a) 88 expected = torch.tensor([True], device=device, dtype=torch.bool) 89 self.assertEqual( 90 actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}" 91 ) 92 93 actual = torch.full_like(b, complex(2, 2)) 94 torch.eq(a, b, out=actual) 95 expected = torch.tensor([complex(0)], device=device, dtype=dtype) 96 self.assertEqual( 97 actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}" 98 ) 99 100 actual = torch.full_like(b, complex(2, 2)) 101 torch.eq(a, a, out=actual) 102 expected = torch.tensor([complex(1)], device=device, dtype=dtype) 103 self.assertEqual( 104 actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}" 105 ) 106 107 # Vectorized operations 108 for a, b in ( 109 ( 110 torch.tensor( 111 [ 112 -0.0610 - 2.1172j, 113 5.1576 + 5.4775j, 114 complex(2.8871, nan), 115 -6.6545 - 3.7655j, 116 -2.7036 - 1.4470j, 117 0.3712 + 7.989j, 118 -0.0610 - 2.1172j, 119 5.1576 + 5.4775j, 120 complex(nan, -3.2650), 121 -6.6545 - 3.7655j, 122 -2.7036 - 1.4470j, 123 0.3712 + 7.989j, 124 ], 125 device=device, 126 dtype=dtype, 127 ), 128 torch.tensor( 129 [ 130 -6.1278 - 8.5019j, 131 0.5886 + 8.8816j, 132 complex(2.8871, nan), 133 6.3505 + 2.2683j, 134 0.3712 + 7.9659j, 135 0.3712 + 7.989j, 136 -6.1278 - 2.1172j, 137 5.1576 + 8.8816j, 138 complex(nan, -3.2650), 139 6.3505 + 2.2683j, 140 0.3712 + 7.9659j, 141 0.3712 + 7.989j, 142 ], 143 device=device, 144 dtype=dtype, 145 ), 146 ), 147 ): 148 actual = torch.eq(a, b) 149 expected = torch.tensor( 150 [ 151 False, 152 False, 153 False, 154 False, 155 False, 156 True, 157 False, 158 False, 159 False, 160 False, 161 False, 162 True, 163 ], 164 device=device, 165 dtype=torch.bool, 166 ) 167 self.assertEqual( 168 actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}" 169 ) 170 171 actual = torch.eq(a, a) 172 expected = torch.tensor( 173 [ 174 True, 175 True, 176 False, 177 True, 178 True, 179 True, 180 True, 181 True, 182 False, 183 True, 184 True, 185 True, 186 ], 187 device=device, 188 dtype=torch.bool, 189 ) 190 self.assertEqual( 191 actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}" 192 ) 193 194 actual = torch.full_like(b, complex(2, 2)) 195 torch.eq(a, b, out=actual) 196 expected = torch.tensor( 197 [ 198 complex(0), 199 complex(0), 200 complex(0), 201 complex(0), 202 complex(0), 203 complex(1), 204 complex(0), 205 complex(0), 206 complex(0), 207 complex(0), 208 complex(0), 209 complex(1), 210 ], 211 device=device, 212 dtype=dtype, 213 ) 214 self.assertEqual( 215 actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}" 216 ) 217 218 actual = torch.full_like(b, complex(2, 2)) 219 torch.eq(a, a, out=actual) 220 expected = torch.tensor( 221 [ 222 complex(1), 223 complex(1), 224 complex(0), 225 complex(1), 226 complex(1), 227 complex(1), 228 complex(1), 229 complex(1), 230 complex(0), 231 complex(1), 232 complex(1), 233 complex(1), 234 ], 235 device=device, 236 dtype=dtype, 237 ) 238 self.assertEqual( 239 actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}" 240 ) 241 242 @onlyCPU 243 @dtypes(*complex_types()) 244 def test_ne(self, device, dtype): 245 "Test ne on complex types" 246 nan = float("nan") 247 # Non-vectorized operations 248 for a, b in ( 249 ( 250 torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype), 251 torch.tensor([-6.1278 - 8.5019j], device=device, dtype=dtype), 252 ), 253 ( 254 torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype), 255 torch.tensor([-6.1278 - 2.1172j], device=device, dtype=dtype), 256 ), 257 ( 258 torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype), 259 torch.tensor([-0.0610 - 8.5019j], device=device, dtype=dtype), 260 ), 261 ): 262 actual = torch.ne(a, b) 263 expected = torch.tensor([True], device=device, dtype=torch.bool) 264 self.assertEqual( 265 actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}" 266 ) 267 268 actual = torch.ne(a, a) 269 expected = torch.tensor([False], device=device, dtype=torch.bool) 270 self.assertEqual( 271 actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}" 272 ) 273 274 actual = torch.full_like(b, complex(2, 2)) 275 torch.ne(a, b, out=actual) 276 expected = torch.tensor([complex(1)], device=device, dtype=dtype) 277 self.assertEqual( 278 actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}" 279 ) 280 281 actual = torch.full_like(b, complex(2, 2)) 282 torch.ne(a, a, out=actual) 283 expected = torch.tensor([complex(0)], device=device, dtype=dtype) 284 self.assertEqual( 285 actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}" 286 ) 287 288 # Vectorized operations 289 for a, b in ( 290 ( 291 torch.tensor( 292 [ 293 -0.0610 - 2.1172j, 294 5.1576 + 5.4775j, 295 complex(2.8871, nan), 296 -6.6545 - 3.7655j, 297 -2.7036 - 1.4470j, 298 0.3712 + 7.989j, 299 -0.0610 - 2.1172j, 300 5.1576 + 5.4775j, 301 complex(nan, -3.2650), 302 -6.6545 - 3.7655j, 303 -2.7036 - 1.4470j, 304 0.3712 + 7.989j, 305 ], 306 device=device, 307 dtype=dtype, 308 ), 309 torch.tensor( 310 [ 311 -6.1278 - 8.5019j, 312 0.5886 + 8.8816j, 313 complex(2.8871, nan), 314 6.3505 + 2.2683j, 315 0.3712 + 7.9659j, 316 0.3712 + 7.989j, 317 -6.1278 - 2.1172j, 318 5.1576 + 8.8816j, 319 complex(nan, -3.2650), 320 6.3505 + 2.2683j, 321 0.3712 + 7.9659j, 322 0.3712 + 7.989j, 323 ], 324 device=device, 325 dtype=dtype, 326 ), 327 ), 328 ): 329 actual = torch.ne(a, b) 330 expected = torch.tensor( 331 [ 332 True, 333 True, 334 True, 335 True, 336 True, 337 False, 338 True, 339 True, 340 True, 341 True, 342 True, 343 False, 344 ], 345 device=device, 346 dtype=torch.bool, 347 ) 348 self.assertEqual( 349 actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}" 350 ) 351 352 actual = torch.ne(a, a) 353 expected = torch.tensor( 354 [ 355 False, 356 False, 357 True, 358 False, 359 False, 360 False, 361 False, 362 False, 363 True, 364 False, 365 False, 366 False, 367 ], 368 device=device, 369 dtype=torch.bool, 370 ) 371 self.assertEqual( 372 actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}" 373 ) 374 375 actual = torch.full_like(b, complex(2, 2)) 376 torch.ne(a, b, out=actual) 377 expected = torch.tensor( 378 [ 379 complex(1), 380 complex(1), 381 complex(1), 382 complex(1), 383 complex(1), 384 complex(0), 385 complex(1), 386 complex(1), 387 complex(1), 388 complex(1), 389 complex(1), 390 complex(0), 391 ], 392 device=device, 393 dtype=dtype, 394 ) 395 self.assertEqual( 396 actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}" 397 ) 398 399 actual = torch.full_like(b, complex(2, 2)) 400 torch.ne(a, a, out=actual) 401 expected = torch.tensor( 402 [ 403 complex(0), 404 complex(0), 405 complex(1), 406 complex(0), 407 complex(0), 408 complex(0), 409 complex(0), 410 complex(0), 411 complex(1), 412 complex(0), 413 complex(0), 414 complex(0), 415 ], 416 device=device, 417 dtype=dtype, 418 ) 419 self.assertEqual( 420 actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}" 421 ) 422 423 424instantiate_device_type_tests(TestComplexTensor, globals()) 425 426if __name__ == "__main__": 427 TestCase._default_dtype_check_enabled = True 428 run_tests() 429