xref: /aosp_15_r20/external/pytorch/test/test_complex.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2# Owner(s): ["module: complex"]
3
4import torch
5from torch.testing._internal.common_device_type import (
6    dtypes,
7    instantiate_device_type_tests,
8    onlyCPU,
9)
10from torch.testing._internal.common_dtype import complex_types
11from torch.testing._internal.common_utils import run_tests, set_default_dtype, TestCase
12
13
# Device handles for the CPU and the first CUDA device.
# NOTE(review): not referenced by any test in this file; candidate for removal.
devices = tuple(torch.device(spec) for spec in ("cpu", "cuda:0"))
15
16
class TestComplexTensor(TestCase):
    """Tests for tensors with complex dtypes.

    Every ``test_*`` method takes ``device`` and ``dtype`` arguments; the
    module-level ``instantiate_device_type_tests`` call presumably expands
    this template into concrete per-device test classes.
    """

    # ------------------------------------------------------------------
    # Operands shared by test_eq and test_ne.  They are kept as plain Python
    # complex scalars so each test builds tensors on its own device/dtype.
    # ------------------------------------------------------------------

    # Single-element ("non-vectorized") pairs: the operands differ in both
    # components, only in the real part, and only in the imaginary part.
    _SCALAR_PAIRS = (
        (-0.0610 - 2.1172j, -6.1278 - 8.5019j),
        (-0.0610 - 2.1172j, -6.1278 - 2.1172j),
        (-0.0610 - 2.1172j, -0.0610 - 8.5019j),
    )

    # Twelve-element ("vectorized") operands.  Positions 2 and 8 hold a NaN
    # component in both operands (never equal, not even to itself).
    # Position 6 differs only in the real part and position 7 only in the
    # imaginary part.  Positions 5 and 11 are identical in both operands.
    _VECTOR_LHS = (
        -0.0610 - 2.1172j,
        5.1576 + 5.4775j,
        complex(2.8871, float("nan")),
        -6.6545 - 3.7655j,
        -2.7036 - 1.4470j,
        0.3712 + 7.989j,
        -0.0610 - 2.1172j,
        5.1576 + 5.4775j,
        complex(float("nan"), -3.2650),
        -6.6545 - 3.7655j,
        -2.7036 - 1.4470j,
        0.3712 + 7.989j,
    )
    _VECTOR_RHS = (
        -6.1278 - 8.5019j,
        0.5886 + 8.8816j,
        complex(2.8871, float("nan")),
        6.3505 + 2.2683j,
        0.3712 + 7.9659j,
        0.3712 + 7.989j,
        -6.1278 - 2.1172j,
        5.1576 + 8.8816j,
        complex(float("nan"), -3.2650),
        6.3505 + 2.2683j,
        0.3712 + 7.9659j,
        0.3712 + 7.989j,
    )
    # Expected element-wise ``eq`` results for (LHS, RHS) and (LHS, LHS).
    _VECTOR_EQ_LHS_RHS = (
        False, False, False, False, False, True,
        False, False, False, False, False, True,
    )
    _VECTOR_EQ_LHS_LHS = (
        True, True, False, True, True, True,
        True, True, False, True, True, True,
    )

    @dtypes(*complex_types())
    def test_to_list(self, device, dtype):
        # test that the complex float tensor has expected values and
        # there's no garbage value in the resultant list
        self.assertEqual(
            torch.zeros((2, 2), device=device, dtype=dtype).tolist(),
            [[0j, 0j], [0j, 0j]],
        )

    @dtypes(torch.float32, torch.float64, torch.float16)
    def test_dtype_inference(self, device, dtype):
        # issue: https://github.com/pytorch/pytorch/issues/36834
        # A complex Python literal must infer the complex dtype that matches
        # the current default floating dtype: half -> chalf,
        # float -> cfloat, double -> cdouble.
        with set_default_dtype(dtype):
            x = torch.tensor([3.0, 3.0 + 5.0j], device=device)
        if dtype == torch.float16:
            self.assertEqual(x.dtype, torch.chalf)
        elif dtype == torch.float32:
            self.assertEqual(x.dtype, torch.cfloat)
        else:
            self.assertEqual(x.dtype, torch.cdouble)

    @dtypes(*complex_types())
    def test_conj_copy(self, device, dtype):
        # issue: https://github.com/pytorch/pytorch/issues/106051
        # Copying the conjugate of x1 back into x1 itself must leave the
        # conjugated values in x1.
        x1 = torch.tensor([5 + 1j, 2 + 2j], device=device, dtype=dtype)
        xc1 = torch.conj(x1)
        x1.copy_(xc1)
        self.assertEqual(x1, torch.tensor([5 - 1j, 2 - 2j], device=device, dtype=dtype))

    @dtypes(*complex_types())
    def test_all(self, device, dtype):
        # issue: https://github.com/pytorch/pytorch/issues/120875
        # Every element is nonzero, including the purely imaginary 5j and the
        # purely real 6, so all() must be True.
        x = torch.tensor([1 + 2j, 3 - 4j, 5j, 6], device=device, dtype=dtype)
        self.assertTrue(torch.all(x))

    @dtypes(*complex_types())
    def test_any(self, device, dtype):
        # issue: https://github.com/pytorch/pytorch/issues/120875
        # Complex zeros with every combination of component signs.  They are
        # built with complex() from float zeros: in a literal such as
        # ``-0 + 0j`` the ``-0`` is the *integer* zero, so no negative zero
        # would ever reach the tensor.
        x = torch.tensor(
            [
                0,
                0j,
                complex(-0.0, 0.0),
                complex(-0.0, -0.0),
                complex(0.0, 0.0),
                complex(0.0, -0.0),
            ],
            device=device,
            dtype=dtype,
        )
        self.assertFalse(torch.any(x))

        # One element with only a nonzero imaginary part makes any() True.
        y = x.clone()
        y[-1] = 1j
        self.assertTrue(torch.any(y))

    def _check_comparison(
        self, op, name, device, dtype, lhs, rhs, expected_lhs_rhs, expected_lhs_lhs
    ):
        """Check a complex comparison op as a bool result and via ``out=``.

        Builds ``a`` from ``lhs`` and ``b`` from ``rhs``.  It then checks, in
        this order: ``op(a, b)``, ``op(a, a)``, and the same two comparisons
        written into a complex ``out=`` tensor.

        Args:
            op: comparison function, e.g. ``torch.eq`` or ``torch.ne``.
            name: op name shown in assertion messages.
            device, dtype: device and complex dtype for the operand tensors.
            lhs, rhs: sequences of Python scalars for ``a`` and ``b``.
            expected_lhs_rhs: expected booleans for ``op(a, b)``.
            expected_lhs_lhs: expected booleans for ``op(a, a)``.
        """
        a = torch.tensor(lhs, device=device, dtype=dtype)
        b = torch.tensor(rhs, device=device, dtype=dtype)
        cases = ((b, expected_lhs_rhs), (a, expected_lhs_lhs))

        # Default output is a bool tensor.
        for other, flags in cases:
            actual = op(a, other)
            expected = torch.tensor(flags, device=device, dtype=torch.bool)
            self.assertEqual(
                actual, expected, msg=f"\n{name}\nactual {actual}\nexpected {expected}"
            )

        # Complex ``out=`` tensor: True/False must be stored as 1/0.  It is
        # pre-filled with 2+2j, which matches neither value, so any element
        # the op fails to write makes the check fail.
        for other, flags in cases:
            actual = torch.full_like(b, complex(2, 2))
            op(a, other, out=actual)
            expected = torch.tensor(
                [complex(flag) for flag in flags], device=device, dtype=dtype
            )
            self.assertEqual(
                actual,
                expected,
                msg=f"\n{name}(out)\nactual {actual}\nexpected {expected}",
            )

    @onlyCPU
    @dtypes(*complex_types())
    def test_eq(self, device, dtype):
        "Test eq on complex types"
        # Non-vectorized operations: each single-element pair differs, and
        # every operand equals itself.
        for lhs, rhs in self._SCALAR_PAIRS:
            self._check_comparison(
                torch.eq, "eq", device, dtype, [lhs], [rhs], [False], [True]
            )

        # Vectorized operations
        self._check_comparison(
            torch.eq,
            "eq",
            device,
            dtype,
            self._VECTOR_LHS,
            self._VECTOR_RHS,
            self._VECTOR_EQ_LHS_RHS,
            self._VECTOR_EQ_LHS_LHS,
        )

    @onlyCPU
    @dtypes(*complex_types())
    def test_ne(self, device, dtype):
        "Test ne on complex types"
        # ne must be the element-wise negation of eq, including the NaN
        # positions (NaN != NaN is True).
        # Non-vectorized operations
        for lhs, rhs in self._SCALAR_PAIRS:
            self._check_comparison(
                torch.ne, "ne", device, dtype, [lhs], [rhs], [True], [False]
            )

        # Vectorized operations
        self._check_comparison(
            torch.ne,
            "ne",
            device,
            dtype,
            self._VECTOR_LHS,
            self._VECTOR_RHS,
            [not flag for flag in self._VECTOR_EQ_LHS_RHS],
            [not flag for flag in self._VECTOR_EQ_LHS_LHS],
        )
422
423
# Expand the generic TestComplexTensor template into concrete test classes in
# this module's namespace (presumably one per available device type, going by
# the helper's name -- confirm against torch.testing._internal).
instantiate_device_type_tests(TestComplexTensor, globals())

if __name__ == "__main__":
    # Enable TestCase's default-dtype check before running the suite.
    # NOTE(review): presumably this verifies that tests leave the global
    # default dtype unchanged -- confirm in common_utils.TestCase.
    TestCase._default_dtype_check_enabled = True
    run_tests()
429