# Owner(s): ["oncall: jit"]

import cmath
import os
import sys
from itertools import product
from textwrap import dedent
from typing import Dict, List

import torch
from torch.testing._internal.common_utils import IS_MACOS
from torch.testing._internal.jit_utils import execWrapper, JitTestCase


# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)


class TestComplex(JitTestCase):
    """TorchScript tests for Python ``complex`` scalars and complex tensors.

    Most tests compile a small function with TorchScript (``checkScript`` or
    ``torch.jit.script``) and compare its result with eager Python execution.
    """

    def test_script(self):
        # A complex scalar passes through a scripted function unchanged.
        def fn(a: complex):
            return a

        self.checkScript(fn, (3 + 5j,))

    def test_complexlist(self):
        # Indexing into a List[complex] argument.
        def fn(a: List[complex], idx: int):
            return a[idx]

        inputs = [1j, 2, 3 + 4j, -5, -7j]
        self.checkScript(fn, (inputs, 2))

    def test_complexdict(self):
        # Looking up a Dict[complex, complex] argument by a complex key.
        def fn(a: Dict[complex, complex], key: complex) -> complex:
            return a[key]

        inputs = {2 + 3j: 2 - 3j, -4.3 - 2j: 3j}
        self.checkScript(fn, (inputs, -4.3 - 2j))

    def test_pickle(self):
        # Complex attributes (scalar, list, dict) and a complex-returning
        # method must survive getExportImportCopy's export/import round trip.
        class ComplexModule(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.a = 3 + 5j
                self.b = [2 + 3j, 3 + 4j, 0 - 3j, -4 + 0j]
                self.c = {2 + 3j: 2 - 3j, -4.3 - 2j: 3j}

            @torch.jit.script_method
            def forward(self, b: int):
                return b + 2j

        loaded = self.getExportImportCopy(ComplexModule())
        self.assertEqual(loaded.a, 3 + 5j)
        self.assertEqual(loaded.b, [2 + 3j, 3 + 4j, -3j, -4])
        self.assertEqual(loaded.c, {2 + 3j: 2 - 3j, -4.3 - 2j: 3j})
        self.assertEqual(loaded(2), 2 + 2j)

    def test_complex_parse(self):
        # Complex literals in assignments and mixed int/complex arithmetic.
        def fn(a: int, b: torch.Tensor, dim: int):
            # verifies the behavior of `emitValueToTensor()`
            b[dim] = 2.4 + 0.5j
            return (3 * 2j) + a + 5j - 7.4j - 4

        t1 = torch.tensor(1)
        t2 = torch.tensor([0.4, 1.4j, 2.35])

        self.checkScript(fn, (t1, t2, 2))

    def test_complex_constants_and_ops(self):
        # Compare TorchScript against eager Python for cmath unary functions,
        # abs/pow, cmath.rect and the cmath constants.
        vals = (
            [0.0, 1.0, 2.2, -1.0, -0.0, -2.2, 1, 0, 2]
            + [10.0**i for i in range(2)]
            + [-(10.0**i) for i in range(2)]
        )
        complex_vals = tuple(complex(x, y) for x, y in product(vals, vals))

        funcs_template = dedent(
            """
            def func(a: complex):
                return cmath.{func_or_const}(a)
            """
        )

        def checkCmath(func_name, funcs_template=funcs_template):
            # Build identical source for eager Python (exec) and TorchScript
            # (CompilationUnit), then compare the two input by input.
            funcs_str = funcs_template.format(func_or_const=func_name)
            scope = {}
            execWrapper(funcs_str, globals(), scope)
            cu = torch.jit.CompilationUnit(funcs_str)
            f_script = cu.func
            f = scope["func"]

            if func_name in ["isinf", "isnan", "isfinite"]:
                # The classification predicates also get non-finite parts.
                new_vals = vals + ([float("inf"), float("nan"), -1 * float("inf")])
                final_vals = tuple(
                    complex(x, y) for x, y in product(new_vals, new_vals)
                )
            else:
                final_vals = complex_vals

            for a in final_vals:
                res_python = None
                res_script = None
                try:
                    res_python = f(a)
                except Exception as e:
                    res_python = e
                try:
                    res_script = f_script(a)
                except Exception as e:
                    res_script = e

                if res_python != res_script:
                    # Inputs on which eager Python raises are skipped rather
                    # than requiring TorchScript to raise the same error.
                    if isinstance(res_python, Exception):
                        continue

                    msg = f"Failed on {func_name} with input {a}. Python: {res_python}, Script: {res_script}"
                    self.assertEqual(res_python, res_script, msg=msg)

        unary_ops = [
            "log",
            "log10",
            "sqrt",
            "exp",
            "sin",
            "cos",
            "asin",
            "acos",
            "atan",
            "sinh",
            "cosh",
            "tanh",
            "asinh",
            "acosh",
            "atanh",
            "phase",
            "isinf",
            "isnan",
            "isfinite",
        ]

        # --- Unary ops ---
        for op in unary_ops:
            checkCmath(op)

        def fn(x: complex):
            return abs(x)

        for val in complex_vals:
            self.checkScript(fn, (val,))

        def pow_complex_float(x: complex, y: float):
            return pow(x, y)

        def pow_float_complex(x: float, y: complex):
            return pow(x, y)

        self.checkScript(pow_float_complex, (2, 3j))
        self.checkScript(pow_complex_float, (3j, 2))

        def pow_complex_complex(x: complex, y: complex):
            return pow(x, y)

        # Note: zip pairs every value with itself, so this checks x ** x.
        for x, y in zip(complex_vals, complex_vals):
            # Reference: https://github.com/pytorch/pytorch/issues/54622
            if x == 0:
                continue
            self.checkScript(pow_complex_complex, (x, y))

        if not IS_MACOS:
            # --- Binary op ---
            def rect_fn(x: float, y: float):
                return cmath.rect(x, y)

            for x, y in product(vals, vals):
                self.checkScript(rect_fn, (x, y))

        func_constants_template = dedent(
            """
            def func():
                return cmath.{func_or_const}
            """
        )

        def checkCmathConstant(const_name):
            # Constants are read through a zero-argument function, so both
            # versions are called with no inputs. (Routing them through
            # checkCmath would call func(a): both sides raise and every
            # comparison is skipped.)
            funcs_str = func_constants_template.format(func_or_const=const_name)
            scope = {}
            execWrapper(funcs_str, globals(), scope)
            cu = torch.jit.CompilationUnit(funcs_str)
            res_python = scope["func"]()
            res_script = cu.func()
            msg = f"Failed on constant {const_name}. Python: {res_python}, Script: {res_script}"
            self.assertEqual(res_python, res_script, msg=msg)

        float_consts = ["pi", "e", "tau", "inf", "nan"]
        complex_consts = ["infj", "nanj"]
        for x in float_consts + complex_consts:
            checkCmathConstant(x)

    def test_infj_nanj_pickle(self):
        # cmath.infj / cmath.nanj used inside a scripted method must survive
        # getExportImportCopy's export/import round trip.
        class ComplexModule(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.a = 3 + 5j

            @torch.jit.script_method
            def forward(self, infj: int, nanj: int):
                if infj == 2:
                    return infj + cmath.infj
                else:
                    return nanj + cmath.nanj

        loaded = self.getExportImportCopy(ComplexModule())
        self.assertEqual(loaded(2, 3), 2 + cmath.infj)
        self.assertEqual(loaded(3, 4), 4 + cmath.nanj)

    def test_complex_constructor(self):
        # Test all scalar types, and each mixed pairing of them, as the
        # (real, imag) arguments of complex().
        def fn_int(real: int, img: int):
            return complex(real, img)

        for args in [(0, 0), (-1234, 0), (0, -1256), (-167, -1256)]:
            self.checkScript(fn_int, args)

        def fn_float(real: float, img: float):
            return complex(real, img)

        for args in [(0.0, 0.0), (-1234.78, 0), (0, 56.18), (-1.9, -19.8)]:
            self.checkScript(fn_float, args)

        def fn_bool(real: bool, img: bool):
            return complex(real, img)

        for args in [(True, True), (False, False), (False, True), (True, False)]:
            self.checkScript(fn_bool, args)

        def fn_bool_int(real: bool, img: int):
            return complex(real, img)

        for args in [(True, 0), (False, 0), (False, -1), (True, 3)]:
            self.checkScript(fn_bool_int, args)

        def fn_int_bool(real: int, img: bool):
            return complex(real, img)

        for args in [(0, True), (0, False), (-3, True), (6, False)]:
            self.checkScript(fn_int_bool, args)

        def fn_bool_float(real: bool, img: float):
            return complex(real, img)

        for args in [(True, 0.0), (False, 0.0), (False, -1.0), (True, 3.0)]:
            self.checkScript(fn_bool_float, args)

        def fn_float_bool(real: float, img: bool):
            return complex(real, img)

        for args in [(0.0, True), (0.0, False), (-3.0, True), (6.0, False)]:
            self.checkScript(fn_float_bool, args)

        def fn_float_int(real: float, img: int):
            return complex(real, img)

        for args in [(0.0, 1), (0.0, -1), (1.8, -3), (2.7, 8)]:
            self.checkScript(fn_float_int, args)

        def fn_int_float(real: int, img: float):
            return complex(real, img)

        for args in [(1, 0.0), (-1, 1.7), (-3, 0.0), (2, -8.9)]:
            self.checkScript(fn_int_float, args)

    def test_torch_complex_constructor_with_tensor(self):
        # complex() with one-element tensors (float, int and bool dtypes)
        # mixed with Python scalars, and with two tensors.
        tensors = [torch.rand(1), torch.randint(-5, 5, (1,)), torch.tensor([False])]

        def fn_tensor_float(real, img: float):
            return complex(real, img)

        def fn_tensor_int(real, img: int):
            return complex(real, img)

        def fn_tensor_bool(real, img: bool):
            return complex(real, img)

        def fn_float_tensor(real: float, img):
            return complex(real, img)

        def fn_int_tensor(real: int, img):
            return complex(real, img)

        def fn_bool_tensor(real: bool, img):
            return complex(real, img)

        for tensor in tensors:
            self.checkScript(fn_tensor_float, (tensor, 1.2))
            self.checkScript(fn_tensor_int, (tensor, 3))
            self.checkScript(fn_tensor_bool, (tensor, True))

            self.checkScript(fn_float_tensor, (1.2, tensor))
            self.checkScript(fn_int_tensor, (3, tensor))
            self.checkScript(fn_bool_tensor, (True, tensor))

        def fn_tensor_tensor(real, img):
            return complex(real, img) + complex(2)

        for x, y in product(tensors, tensors):
            self.checkScript(fn_tensor_tensor, (x, y))

    def test_comparison_ops(self):
        # == and != between two complex values and between complex and float.
        def fn1(a: complex, b: complex):
            return a == b

        def fn2(a: complex, b: complex):
            return a != b

        def fn3(a: complex, b: float):
            return a == b

        def fn4(a: complex, b: float):
            return a != b

        x, y = 2 - 3j, 4j
        self.checkScript(fn1, (x, x))
        self.checkScript(fn1, (x, y))
        self.checkScript(fn2, (x, x))
        self.checkScript(fn2, (x, y))

        x1, y1 = 1 + 0j, 1.0
        self.checkScript(fn3, (x1, y1))
        self.checkScript(fn4, (x1, y1))

    def test_div(self):
        # Division of two complex scalars.
        def fn1(a: complex, b: complex):
            return a / b

        x, y = 2 - 3j, 4j
        self.checkScript(fn1, (x, y))

    def test_complex_list_sum(self):
        # sum() over a List[complex].
        def fn(x: List[complex]):
            return sum(x)

        self.checkScript(fn, (torch.randn(4, dtype=torch.cdouble).tolist(),))

    def test_tensor_attributes(self):
        # .real / .imag attributes of a complex tensor.
        def tensor_real(x):
            return x.real

        def tensor_imag(x):
            return x.imag

        t = torch.randn(2, 3, dtype=torch.cdouble)
        self.checkScript(tensor_real, (t,))
        self.checkScript(tensor_imag, (t,))

    def test_binary_op_complex_tensor(self):
        # Binary ops with a complex scalar on the left and a complex tensor
        # on the right must give the same result scripted and eager.
        def mul(x: complex, y: torch.Tensor):
            return x * y

        def add(x: complex, y: torch.Tensor):
            return x + y

        def eq(x: complex, y: torch.Tensor):
            return x == y

        def ne(x: complex, y: torch.Tensor):
            return x != y

        def sub(x: complex, y: torch.Tensor):
            return x - y

        def div(x: complex, y: torch.Tensor):
            # Previously returned `x - y`, so division was never exercised.
            return x / y

        ops = [mul, add, eq, ne, sub, div]

        for shape in [(1,), (2, 2)]:
            x = 0.71 + 0.71j
            y = torch.randn(shape, dtype=torch.cfloat)
            for op in ops:
                eager_result = op(x, y)
                scripted = torch.jit.script(op)
                jit_result = scripted(x, y)
                self.assertEqual(eager_result, jit_result)