# Owner(s): ["oncall: jit"] import cmath import os import sys from itertools import product from textwrap import dedent from typing import Dict, List import torch from torch.testing._internal.common_utils import IS_MACOS from torch.testing._internal.jit_utils import execWrapper, JitTestCase # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) sys.path.append(pytorch_test_dir) class TestComplex(JitTestCase): def test_script(self): def fn(a: complex): return a self.checkScript(fn, (3 + 5j,)) def test_complexlist(self): def fn(a: List[complex], idx: int): return a[idx] input = [1j, 2, 3 + 4j, -5, -7j] self.checkScript(fn, (input, 2)) def test_complexdict(self): def fn(a: Dict[complex, complex], key: complex) -> complex: return a[key] input = {2 + 3j: 2 - 3j, -4.3 - 2j: 3j} self.checkScript(fn, (input, -4.3 - 2j)) def test_pickle(self): class ComplexModule(torch.jit.ScriptModule): def __init__(self) -> None: super().__init__() self.a = 3 + 5j self.b = [2 + 3j, 3 + 4j, 0 - 3j, -4 + 0j] self.c = {2 + 3j: 2 - 3j, -4.3 - 2j: 3j} @torch.jit.script_method def forward(self, b: int): return b + 2j loaded = self.getExportImportCopy(ComplexModule()) self.assertEqual(loaded.a, 3 + 5j) self.assertEqual(loaded.b, [2 + 3j, 3 + 4j, -3j, -4]) self.assertEqual(loaded.c, {2 + 3j: 2 - 3j, -4.3 - 2j: 3j}) self.assertEqual(loaded(2), 2 + 2j) def test_complex_parse(self): def fn(a: int, b: torch.Tensor, dim: int): # verifies `emitValueToTensor()` 's behavior b[dim] = 2.4 + 0.5j return (3 * 2j) + a + 5j - 7.4j - 4 t1 = torch.tensor(1) t2 = torch.tensor([0.4, 1.4j, 2.35]) self.checkScript(fn, (t1, t2, 2)) def test_complex_constants_and_ops(self): vals = ( [0.0, 1.0, 2.2, -1.0, -0.0, -2.2, 1, 0, 2] + [10.0**i for i in range(2)] + [-(10.0**i) for i in range(2)] ) complex_vals = tuple(complex(x, y) for x, y in product(vals, vals)) funcs_template = dedent( """ def func(a: complex): return cmath.{func_or_const}(a) """ ) def checkCmath(func_name, funcs_template=funcs_template): funcs_str = funcs_template.format(func_or_const=func_name) scope = {} execWrapper(funcs_str, globals(), scope) cu = torch.jit.CompilationUnit(funcs_str) f_script = cu.func f = scope["func"] if func_name in ["isinf", "isnan", "isfinite"]: new_vals = vals + ([float("inf"), float("nan"), -1 * float("inf")]) final_vals = tuple( complex(x, y) for x, y in product(new_vals, new_vals) ) else: final_vals = complex_vals for a in final_vals: res_python = None res_script = None try: res_python = f(a) except Exception as e: res_python = e try: res_script = f_script(a) except Exception as e: res_script = e if res_python != res_script: if isinstance(res_python, Exception): continue msg = f"Failed on {func_name} with input {a}. Python: {res_python}, Script: {res_script}" self.assertEqual(res_python, res_script, msg=msg) unary_ops = [ "log", "log10", "sqrt", "exp", "sin", "cos", "asin", "acos", "atan", "sinh", "cosh", "tanh", "asinh", "acosh", "atanh", "phase", "isinf", "isnan", "isfinite", ] # --- Unary ops --- for op in unary_ops: checkCmath(op) def fn(x: complex): return abs(x) for val in complex_vals: self.checkScript(fn, (val,)) def pow_complex_float(x: complex, y: float): return pow(x, y) def pow_float_complex(x: float, y: complex): return pow(x, y) self.checkScript(pow_float_complex, (2, 3j)) self.checkScript(pow_complex_float, (3j, 2)) def pow_complex_complex(x: complex, y: complex): return pow(x, y) for x, y in zip(complex_vals, complex_vals): # Reference: https://github.com/pytorch/pytorch/issues/54622 if x == 0: continue self.checkScript(pow_complex_complex, (x, y)) if not IS_MACOS: # --- Binary op --- def rect_fn(x: float, y: float): return cmath.rect(x, y) for x, y in product(vals, vals): self.checkScript( rect_fn, ( x, y, ), ) func_constants_template = dedent( """ def func(): return cmath.{func_or_const} """ ) float_consts = ["pi", "e", "tau", "inf", "nan"] complex_consts = ["infj", "nanj"] for x in float_consts + complex_consts: checkCmath(x, funcs_template=func_constants_template) def test_infj_nanj_pickle(self): class ComplexModule(torch.jit.ScriptModule): def __init__(self) -> None: super().__init__() self.a = 3 + 5j @torch.jit.script_method def forward(self, infj: int, nanj: int): if infj == 2: return infj + cmath.infj else: return nanj + cmath.nanj loaded = self.getExportImportCopy(ComplexModule()) self.assertEqual(loaded(2, 3), 2 + cmath.infj) self.assertEqual(loaded(3, 4), 4 + cmath.nanj) def test_complex_constructor(self): # Test all scalar types def fn_int(real: int, img: int): return complex(real, img) self.checkScript( fn_int, ( 0, 0, ), ) self.checkScript( fn_int, ( -1234, 0, ), ) self.checkScript( fn_int, ( 0, -1256, ), ) self.checkScript( fn_int, ( -167, -1256, ), ) def fn_float(real: float, img: float): return complex(real, img) self.checkScript( fn_float, ( 0.0, 0.0, ), ) self.checkScript( fn_float, ( -1234.78, 0, ), ) self.checkScript( fn_float, ( 0, 56.18, ), ) self.checkScript( fn_float, ( -1.9, -19.8, ), ) def fn_bool(real: bool, img: bool): return complex(real, img) self.checkScript( fn_bool, ( True, True, ), ) self.checkScript( fn_bool, ( False, False, ), ) self.checkScript( fn_bool, ( False, True, ), ) self.checkScript( fn_bool, ( True, False, ), ) def fn_bool_int(real: bool, img: int): return complex(real, img) self.checkScript( fn_bool_int, ( True, 0, ), ) self.checkScript( fn_bool_int, ( False, 0, ), ) self.checkScript( fn_bool_int, ( False, -1, ), ) self.checkScript( fn_bool_int, ( True, 3, ), ) def fn_int_bool(real: int, img: bool): return complex(real, img) self.checkScript( fn_int_bool, ( 0, True, ), ) self.checkScript( fn_int_bool, ( 0, False, ), ) self.checkScript( fn_int_bool, ( -3, True, ), ) self.checkScript( fn_int_bool, ( 6, False, ), ) def fn_bool_float(real: bool, img: float): return complex(real, img) self.checkScript( fn_bool_float, ( True, 0.0, ), ) self.checkScript( fn_bool_float, ( False, 0.0, ), ) self.checkScript( fn_bool_float, ( False, -1.0, ), ) self.checkScript( fn_bool_float, ( True, 3.0, ), ) def fn_float_bool(real: float, img: bool): return complex(real, img) self.checkScript( fn_float_bool, ( 0.0, True, ), ) self.checkScript( fn_float_bool, ( 0.0, False, ), ) self.checkScript( fn_float_bool, ( -3.0, True, ), ) self.checkScript( fn_float_bool, ( 6.0, False, ), ) def fn_float_int(real: float, img: int): return complex(real, img) self.checkScript( fn_float_int, ( 0.0, 1, ), ) self.checkScript( fn_float_int, ( 0.0, -1, ), ) self.checkScript( fn_float_int, ( 1.8, -3, ), ) self.checkScript( fn_float_int, ( 2.7, 8, ), ) def fn_int_float(real: int, img: float): return complex(real, img) self.checkScript( fn_int_float, ( 1, 0.0, ), ) self.checkScript( fn_int_float, ( -1, 1.7, ), ) self.checkScript( fn_int_float, ( -3, 0.0, ), ) self.checkScript( fn_int_float, ( 2, -8.9, ), ) def test_torch_complex_constructor_with_tensor(self): tensors = [torch.rand(1), torch.randint(-5, 5, (1,)), torch.tensor([False])] def fn_tensor_float(real, img: float): return complex(real, img) def fn_tensor_int(real, img: int): return complex(real, img) def fn_tensor_bool(real, img: bool): return complex(real, img) def fn_float_tensor(real: float, img): return complex(real, img) def fn_int_tensor(real: int, img): return complex(real, img) def fn_bool_tensor(real: bool, img): return complex(real, img) for tensor in tensors: self.checkScript(fn_tensor_float, (tensor, 1.2)) self.checkScript(fn_tensor_int, (tensor, 3)) self.checkScript(fn_tensor_bool, (tensor, True)) self.checkScript(fn_float_tensor, (1.2, tensor)) self.checkScript(fn_int_tensor, (3, tensor)) self.checkScript(fn_bool_tensor, (True, tensor)) def fn_tensor_tensor(real, img): return complex(real, img) + complex(2) for x, y in product(tensors, tensors): self.checkScript( fn_tensor_tensor, ( x, y, ), ) def test_comparison_ops(self): def fn1(a: complex, b: complex): return a == b def fn2(a: complex, b: complex): return a != b def fn3(a: complex, b: float): return a == b def fn4(a: complex, b: float): return a != b x, y = 2 - 3j, 4j self.checkScript(fn1, (x, x)) self.checkScript(fn1, (x, y)) self.checkScript(fn2, (x, x)) self.checkScript(fn2, (x, y)) x1, y1 = 1 + 0j, 1.0 self.checkScript(fn3, (x1, y1)) self.checkScript(fn4, (x1, y1)) def test_div(self): def fn1(a: complex, b: complex): return a / b x, y = 2 - 3j, 4j self.checkScript(fn1, (x, y)) def test_complex_list_sum(self): def fn(x: List[complex]): return sum(x) self.checkScript(fn, (torch.randn(4, dtype=torch.cdouble).tolist(),)) def test_tensor_attributes(self): def tensor_real(x): return x.real def tensor_imag(x): return x.imag t = torch.randn(2, 3, dtype=torch.cdouble) self.checkScript(tensor_real, (t,)) self.checkScript(tensor_imag, (t,)) def test_binary_op_complex_tensor(self): def mul(x: complex, y: torch.Tensor): return x * y def add(x: complex, y: torch.Tensor): return x + y def eq(x: complex, y: torch.Tensor): return x == y def ne(x: complex, y: torch.Tensor): return x != y def sub(x: complex, y: torch.Tensor): return x - y def div(x: complex, y: torch.Tensor): return x - y ops = [mul, add, eq, ne, sub, div] for shape in [(1,), (2, 2)]: x = 0.71 + 0.71j y = torch.randn(shape, dtype=torch.cfloat) for op in ops: eager_result = op(x, y) scripted = torch.jit.script(op) jit_result = scripted(x, y) self.assertEqual(eager_result, jit_result)