xref: /aosp_15_r20/external/pytorch/test/autograd/test_complex.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["module: autograd"]
2
3import torch
4from torch.testing._internal.common_utils import gradcheck, run_tests, TestCase
5
6
7class TestAutogradComplex(TestCase):
8    def test_view_func_for_complex_views(self):
9        # case 1: both parent and child have view_func
10        x = torch.randn(2, 2, 2, dtype=torch.double, requires_grad=True)
11        y = x.detach().requires_grad_(True)
12
13        x0 = x.clone()
14        x1 = torch.view_as_complex(x0)
15        x2 = torch.view_as_real(x1)
16        x2.mul_(2)
17        x2.sum().abs().backward()
18
19        y0 = y.clone()
20        y0.mul_(2)
21        y0.sum().abs().backward()
22
23        self.assertEqual(x.grad, y.grad)
24
25        # case 2: parent has view_func but child does not
26        x = torch.randn(2, 2, 2, dtype=torch.double, requires_grad=True)
27        y = x.detach().requires_grad_(True)
28
29        def fn(a):
30            b = a.clone()
31            b1 = torch.view_as_complex(b)
32            b2 = b1.reshape(b1.numel())
33            return b2
34
35        x0 = fn(x)
36        x0.mul_(2)
37        x0.sum().abs().backward()
38
39        y0 = fn(y)
40        y1 = y0.mul(2)
41        y1.sum().abs().backward()
42
43        self.assertEqual(x.grad, y.grad)
44
45        # case 3: parent does not have a view_func but child does
46        x = torch.randn(10, dtype=torch.cdouble, requires_grad=True)
47        y = x.detach().requires_grad_(True)
48
49        def fn(a, dim0_size=5):
50            b = a.clone()
51            b1 = b.reshape(dim0_size, 2)
52            b2 = torch.view_as_real(b1)
53            return b2
54
55        x0 = fn(x)
56        x0.mul_(2)
57        x0.sum().abs().backward()
58
59        y0 = fn(y)
60        y1 = y0.mul(2)
61        y1.sum().abs().backward()
62
63        self.assertEqual(x.grad, y.grad)
64
65    def test_view_with_multi_output(self):
66        x = torch.randn(2, 2, 2, dtype=torch.double)
67
68        x1 = torch.view_as_complex(x)
69        # Taking an invalid view should always be allowed as long as it is not
70        # modified inplace
71        res = x1.unbind(0)
72
73        with self.assertRaisesRegex(
74            RuntimeError, "output of a function that returns multiple views"
75        ):
76            res[0] += torch.rand(2, requires_grad=True)
77
78        x.requires_grad_(True)
79        x1 = torch.view_as_complex(x)
80        # Taking an invalid view should always be allowed as long as it is not
81        # modified inplace
82        res = x1.unbind(0)
83
84        with self.assertRaisesRegex(
85            RuntimeError, "output of a function that returns multiple views"
86        ):
87            res[0] += torch.rand(2, requires_grad=True)
88
89    def as_identity(self):
90        # view_as_real and view_as_complex behavior should be like an identity
91        def func(z):
92            z_ = torch.view_as_complex(z)
93            z_select = torch.select(z_, z_.dim() - 1, 0)
94            z_select_real = torch.view_as_real(z_select)
95            return z_select_real.sum()
96
97        z = torch.randn(10, 2, 2, dtype=torch.double, requires_grad=True)
98        gradcheck(func, [z])
99        func(z).backward()
100
101        z1 = z.clone().detach().requires_grad_(True)
102        torch.select(z1, z1.dim() - 2, 0).sum().backward()
103
104        self.assertEqual(z.grad, z1.grad)
105
106
if __name__ == "__main__":
    # Only runs when this file is executed directly, not when it is imported.
    run_tests()
109