xref: /aosp_15_r20/external/pytorch/test/jit/test_complex.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import cmath
4import os
5import sys
6from itertools import product
7from textwrap import dedent
8from typing import Dict, List
9
10import torch
11from torch.testing._internal.common_utils import IS_MACOS
12from torch.testing._internal.jit_utils import execWrapper, JitTestCase
13
14
# Make the helper files under test/ importable: this module lives in
# test/jit/, so stripping two trailing path components from its resolved
# location yields the test/ directory, which is then added to sys.path.
pytorch_test_dir = os.path.split(os.path.split(os.path.realpath(__file__))[0])[0]
sys.path.extend([pytorch_test_dir])
18
19
20class TestComplex(JitTestCase):
21    def test_script(self):
22        def fn(a: complex):
23            return a
24
25        self.checkScript(fn, (3 + 5j,))
26
27    def test_complexlist(self):
28        def fn(a: List[complex], idx: int):
29            return a[idx]
30
31        input = [1j, 2, 3 + 4j, -5, -7j]
32        self.checkScript(fn, (input, 2))
33
34    def test_complexdict(self):
35        def fn(a: Dict[complex, complex], key: complex) -> complex:
36            return a[key]
37
38        input = {2 + 3j: 2 - 3j, -4.3 - 2j: 3j}
39        self.checkScript(fn, (input, -4.3 - 2j))
40
41    def test_pickle(self):
42        class ComplexModule(torch.jit.ScriptModule):
43            def __init__(self) -> None:
44                super().__init__()
45                self.a = 3 + 5j
46                self.b = [2 + 3j, 3 + 4j, 0 - 3j, -4 + 0j]
47                self.c = {2 + 3j: 2 - 3j, -4.3 - 2j: 3j}
48
49            @torch.jit.script_method
50            def forward(self, b: int):
51                return b + 2j
52
53        loaded = self.getExportImportCopy(ComplexModule())
54        self.assertEqual(loaded.a, 3 + 5j)
55        self.assertEqual(loaded.b, [2 + 3j, 3 + 4j, -3j, -4])
56        self.assertEqual(loaded.c, {2 + 3j: 2 - 3j, -4.3 - 2j: 3j})
57        self.assertEqual(loaded(2), 2 + 2j)
58
59    def test_complex_parse(self):
60        def fn(a: int, b: torch.Tensor, dim: int):
61            # verifies `emitValueToTensor()` 's behavior
62            b[dim] = 2.4 + 0.5j
63            return (3 * 2j) + a + 5j - 7.4j - 4
64
65        t1 = torch.tensor(1)
66        t2 = torch.tensor([0.4, 1.4j, 2.35])
67
68        self.checkScript(fn, (t1, t2, 2))
69
70    def test_complex_constants_and_ops(self):
71        vals = (
72            [0.0, 1.0, 2.2, -1.0, -0.0, -2.2, 1, 0, 2]
73            + [10.0**i for i in range(2)]
74            + [-(10.0**i) for i in range(2)]
75        )
76        complex_vals = tuple(complex(x, y) for x, y in product(vals, vals))
77
78        funcs_template = dedent(
79            """
80            def func(a: complex):
81                return cmath.{func_or_const}(a)
82            """
83        )
84
85        def checkCmath(func_name, funcs_template=funcs_template):
86            funcs_str = funcs_template.format(func_or_const=func_name)
87            scope = {}
88            execWrapper(funcs_str, globals(), scope)
89            cu = torch.jit.CompilationUnit(funcs_str)
90            f_script = cu.func
91            f = scope["func"]
92
93            if func_name in ["isinf", "isnan", "isfinite"]:
94                new_vals = vals + ([float("inf"), float("nan"), -1 * float("inf")])
95                final_vals = tuple(
96                    complex(x, y) for x, y in product(new_vals, new_vals)
97                )
98            else:
99                final_vals = complex_vals
100
101            for a in final_vals:
102                res_python = None
103                res_script = None
104                try:
105                    res_python = f(a)
106                except Exception as e:
107                    res_python = e
108                try:
109                    res_script = f_script(a)
110                except Exception as e:
111                    res_script = e
112
113                if res_python != res_script:
114                    if isinstance(res_python, Exception):
115                        continue
116
117                    msg = f"Failed on {func_name} with input {a}. Python: {res_python}, Script: {res_script}"
118                    self.assertEqual(res_python, res_script, msg=msg)
119
120        unary_ops = [
121            "log",
122            "log10",
123            "sqrt",
124            "exp",
125            "sin",
126            "cos",
127            "asin",
128            "acos",
129            "atan",
130            "sinh",
131            "cosh",
132            "tanh",
133            "asinh",
134            "acosh",
135            "atanh",
136            "phase",
137            "isinf",
138            "isnan",
139            "isfinite",
140        ]
141
142        # --- Unary ops ---
143        for op in unary_ops:
144            checkCmath(op)
145
146        def fn(x: complex):
147            return abs(x)
148
149        for val in complex_vals:
150            self.checkScript(fn, (val,))
151
152        def pow_complex_float(x: complex, y: float):
153            return pow(x, y)
154
155        def pow_float_complex(x: float, y: complex):
156            return pow(x, y)
157
158        self.checkScript(pow_float_complex, (2, 3j))
159        self.checkScript(pow_complex_float, (3j, 2))
160
161        def pow_complex_complex(x: complex, y: complex):
162            return pow(x, y)
163
164        for x, y in zip(complex_vals, complex_vals):
165            # Reference: https://github.com/pytorch/pytorch/issues/54622
166            if x == 0:
167                continue
168            self.checkScript(pow_complex_complex, (x, y))
169
170        if not IS_MACOS:
171            # --- Binary op ---
172            def rect_fn(x: float, y: float):
173                return cmath.rect(x, y)
174
175            for x, y in product(vals, vals):
176                self.checkScript(
177                    rect_fn,
178                    (
179                        x,
180                        y,
181                    ),
182                )
183
184        func_constants_template = dedent(
185            """
186            def func():
187                return cmath.{func_or_const}
188            """
189        )
190        float_consts = ["pi", "e", "tau", "inf", "nan"]
191        complex_consts = ["infj", "nanj"]
192        for x in float_consts + complex_consts:
193            checkCmath(x, funcs_template=func_constants_template)
194
195    def test_infj_nanj_pickle(self):
196        class ComplexModule(torch.jit.ScriptModule):
197            def __init__(self) -> None:
198                super().__init__()
199                self.a = 3 + 5j
200
201            @torch.jit.script_method
202            def forward(self, infj: int, nanj: int):
203                if infj == 2:
204                    return infj + cmath.infj
205                else:
206                    return nanj + cmath.nanj
207
208        loaded = self.getExportImportCopy(ComplexModule())
209        self.assertEqual(loaded(2, 3), 2 + cmath.infj)
210        self.assertEqual(loaded(3, 4), 4 + cmath.nanj)
211
212    def test_complex_constructor(self):
213        # Test all scalar types
214        def fn_int(real: int, img: int):
215            return complex(real, img)
216
217        self.checkScript(
218            fn_int,
219            (
220                0,
221                0,
222            ),
223        )
224        self.checkScript(
225            fn_int,
226            (
227                -1234,
228                0,
229            ),
230        )
231        self.checkScript(
232            fn_int,
233            (
234                0,
235                -1256,
236            ),
237        )
238        self.checkScript(
239            fn_int,
240            (
241                -167,
242                -1256,
243            ),
244        )
245
246        def fn_float(real: float, img: float):
247            return complex(real, img)
248
249        self.checkScript(
250            fn_float,
251            (
252                0.0,
253                0.0,
254            ),
255        )
256        self.checkScript(
257            fn_float,
258            (
259                -1234.78,
260                0,
261            ),
262        )
263        self.checkScript(
264            fn_float,
265            (
266                0,
267                56.18,
268            ),
269        )
270        self.checkScript(
271            fn_float,
272            (
273                -1.9,
274                -19.8,
275            ),
276        )
277
278        def fn_bool(real: bool, img: bool):
279            return complex(real, img)
280
281        self.checkScript(
282            fn_bool,
283            (
284                True,
285                True,
286            ),
287        )
288        self.checkScript(
289            fn_bool,
290            (
291                False,
292                False,
293            ),
294        )
295        self.checkScript(
296            fn_bool,
297            (
298                False,
299                True,
300            ),
301        )
302        self.checkScript(
303            fn_bool,
304            (
305                True,
306                False,
307            ),
308        )
309
310        def fn_bool_int(real: bool, img: int):
311            return complex(real, img)
312
313        self.checkScript(
314            fn_bool_int,
315            (
316                True,
317                0,
318            ),
319        )
320        self.checkScript(
321            fn_bool_int,
322            (
323                False,
324                0,
325            ),
326        )
327        self.checkScript(
328            fn_bool_int,
329            (
330                False,
331                -1,
332            ),
333        )
334        self.checkScript(
335            fn_bool_int,
336            (
337                True,
338                3,
339            ),
340        )
341
342        def fn_int_bool(real: int, img: bool):
343            return complex(real, img)
344
345        self.checkScript(
346            fn_int_bool,
347            (
348                0,
349                True,
350            ),
351        )
352        self.checkScript(
353            fn_int_bool,
354            (
355                0,
356                False,
357            ),
358        )
359        self.checkScript(
360            fn_int_bool,
361            (
362                -3,
363                True,
364            ),
365        )
366        self.checkScript(
367            fn_int_bool,
368            (
369                6,
370                False,
371            ),
372        )
373
374        def fn_bool_float(real: bool, img: float):
375            return complex(real, img)
376
377        self.checkScript(
378            fn_bool_float,
379            (
380                True,
381                0.0,
382            ),
383        )
384        self.checkScript(
385            fn_bool_float,
386            (
387                False,
388                0.0,
389            ),
390        )
391        self.checkScript(
392            fn_bool_float,
393            (
394                False,
395                -1.0,
396            ),
397        )
398        self.checkScript(
399            fn_bool_float,
400            (
401                True,
402                3.0,
403            ),
404        )
405
406        def fn_float_bool(real: float, img: bool):
407            return complex(real, img)
408
409        self.checkScript(
410            fn_float_bool,
411            (
412                0.0,
413                True,
414            ),
415        )
416        self.checkScript(
417            fn_float_bool,
418            (
419                0.0,
420                False,
421            ),
422        )
423        self.checkScript(
424            fn_float_bool,
425            (
426                -3.0,
427                True,
428            ),
429        )
430        self.checkScript(
431            fn_float_bool,
432            (
433                6.0,
434                False,
435            ),
436        )
437
438        def fn_float_int(real: float, img: int):
439            return complex(real, img)
440
441        self.checkScript(
442            fn_float_int,
443            (
444                0.0,
445                1,
446            ),
447        )
448        self.checkScript(
449            fn_float_int,
450            (
451                0.0,
452                -1,
453            ),
454        )
455        self.checkScript(
456            fn_float_int,
457            (
458                1.8,
459                -3,
460            ),
461        )
462        self.checkScript(
463            fn_float_int,
464            (
465                2.7,
466                8,
467            ),
468        )
469
470        def fn_int_float(real: int, img: float):
471            return complex(real, img)
472
473        self.checkScript(
474            fn_int_float,
475            (
476                1,
477                0.0,
478            ),
479        )
480        self.checkScript(
481            fn_int_float,
482            (
483                -1,
484                1.7,
485            ),
486        )
487        self.checkScript(
488            fn_int_float,
489            (
490                -3,
491                0.0,
492            ),
493        )
494        self.checkScript(
495            fn_int_float,
496            (
497                2,
498                -8.9,
499            ),
500        )
501
502    def test_torch_complex_constructor_with_tensor(self):
503        tensors = [torch.rand(1), torch.randint(-5, 5, (1,)), torch.tensor([False])]
504
505        def fn_tensor_float(real, img: float):
506            return complex(real, img)
507
508        def fn_tensor_int(real, img: int):
509            return complex(real, img)
510
511        def fn_tensor_bool(real, img: bool):
512            return complex(real, img)
513
514        def fn_float_tensor(real: float, img):
515            return complex(real, img)
516
517        def fn_int_tensor(real: int, img):
518            return complex(real, img)
519
520        def fn_bool_tensor(real: bool, img):
521            return complex(real, img)
522
523        for tensor in tensors:
524            self.checkScript(fn_tensor_float, (tensor, 1.2))
525            self.checkScript(fn_tensor_int, (tensor, 3))
526            self.checkScript(fn_tensor_bool, (tensor, True))
527
528            self.checkScript(fn_float_tensor, (1.2, tensor))
529            self.checkScript(fn_int_tensor, (3, tensor))
530            self.checkScript(fn_bool_tensor, (True, tensor))
531
532        def fn_tensor_tensor(real, img):
533            return complex(real, img) + complex(2)
534
535        for x, y in product(tensors, tensors):
536            self.checkScript(
537                fn_tensor_tensor,
538                (
539                    x,
540                    y,
541                ),
542            )
543
544    def test_comparison_ops(self):
545        def fn1(a: complex, b: complex):
546            return a == b
547
548        def fn2(a: complex, b: complex):
549            return a != b
550
551        def fn3(a: complex, b: float):
552            return a == b
553
554        def fn4(a: complex, b: float):
555            return a != b
556
557        x, y = 2 - 3j, 4j
558        self.checkScript(fn1, (x, x))
559        self.checkScript(fn1, (x, y))
560        self.checkScript(fn2, (x, x))
561        self.checkScript(fn2, (x, y))
562
563        x1, y1 = 1 + 0j, 1.0
564        self.checkScript(fn3, (x1, y1))
565        self.checkScript(fn4, (x1, y1))
566
567    def test_div(self):
568        def fn1(a: complex, b: complex):
569            return a / b
570
571        x, y = 2 - 3j, 4j
572        self.checkScript(fn1, (x, y))
573
574    def test_complex_list_sum(self):
575        def fn(x: List[complex]):
576            return sum(x)
577
578        self.checkScript(fn, (torch.randn(4, dtype=torch.cdouble).tolist(),))
579
580    def test_tensor_attributes(self):
581        def tensor_real(x):
582            return x.real
583
584        def tensor_imag(x):
585            return x.imag
586
587        t = torch.randn(2, 3, dtype=torch.cdouble)
588        self.checkScript(tensor_real, (t,))
589        self.checkScript(tensor_imag, (t,))
590
591    def test_binary_op_complex_tensor(self):
592        def mul(x: complex, y: torch.Tensor):
593            return x * y
594
595        def add(x: complex, y: torch.Tensor):
596            return x + y
597
598        def eq(x: complex, y: torch.Tensor):
599            return x == y
600
601        def ne(x: complex, y: torch.Tensor):
602            return x != y
603
604        def sub(x: complex, y: torch.Tensor):
605            return x - y
606
607        def div(x: complex, y: torch.Tensor):
608            return x - y
609
610        ops = [mul, add, eq, ne, sub, div]
611
612        for shape in [(1,), (2, 2)]:
613            x = 0.71 + 0.71j
614            y = torch.randn(shape, dtype=torch.cfloat)
615            for op in ops:
616                eager_result = op(x, y)
617                scripted = torch.jit.script(op)
618                jit_result = scripted(x, y)
619                self.assertEqual(eager_result, jit_result)
620