1# Owner(s): ["module: autograd"] 2 3import torch 4from torch.testing._internal.common_utils import gradcheck, run_tests, TestCase 5 6 7class TestAutogradComplex(TestCase): 8 def test_view_func_for_complex_views(self): 9 # case 1: both parent and child have view_func 10 x = torch.randn(2, 2, 2, dtype=torch.double, requires_grad=True) 11 y = x.detach().requires_grad_(True) 12 13 x0 = x.clone() 14 x1 = torch.view_as_complex(x0) 15 x2 = torch.view_as_real(x1) 16 x2.mul_(2) 17 x2.sum().abs().backward() 18 19 y0 = y.clone() 20 y0.mul_(2) 21 y0.sum().abs().backward() 22 23 self.assertEqual(x.grad, y.grad) 24 25 # case 2: parent has view_func but child does not 26 x = torch.randn(2, 2, 2, dtype=torch.double, requires_grad=True) 27 y = x.detach().requires_grad_(True) 28 29 def fn(a): 30 b = a.clone() 31 b1 = torch.view_as_complex(b) 32 b2 = b1.reshape(b1.numel()) 33 return b2 34 35 x0 = fn(x) 36 x0.mul_(2) 37 x0.sum().abs().backward() 38 39 y0 = fn(y) 40 y1 = y0.mul(2) 41 y1.sum().abs().backward() 42 43 self.assertEqual(x.grad, y.grad) 44 45 # case 3: parent does not have a view_func but child does 46 x = torch.randn(10, dtype=torch.cdouble, requires_grad=True) 47 y = x.detach().requires_grad_(True) 48 49 def fn(a, dim0_size=5): 50 b = a.clone() 51 b1 = b.reshape(dim0_size, 2) 52 b2 = torch.view_as_real(b1) 53 return b2 54 55 x0 = fn(x) 56 x0.mul_(2) 57 x0.sum().abs().backward() 58 59 y0 = fn(y) 60 y1 = y0.mul(2) 61 y1.sum().abs().backward() 62 63 self.assertEqual(x.grad, y.grad) 64 65 def test_view_with_multi_output(self): 66 x = torch.randn(2, 2, 2, dtype=torch.double) 67 68 x1 = torch.view_as_complex(x) 69 # Taking an invalid view should always be allowed as long as it is not 70 # modified inplace 71 res = x1.unbind(0) 72 73 with self.assertRaisesRegex( 74 RuntimeError, "output of a function that returns multiple views" 75 ): 76 res[0] += torch.rand(2, requires_grad=True) 77 78 x.requires_grad_(True) 79 x1 = torch.view_as_complex(x) 80 # Taking an invalid view should always be allowed as long as it is not 81 # modified inplace 82 res = x1.unbind(0) 83 84 with self.assertRaisesRegex( 85 RuntimeError, "output of a function that returns multiple views" 86 ): 87 res[0] += torch.rand(2, requires_grad=True) 88 89 def as_identity(self): 90 # view_as_real and view_as_complex behavior should be like an identity 91 def func(z): 92 z_ = torch.view_as_complex(z) 93 z_select = torch.select(z_, z_.dim() - 1, 0) 94 z_select_real = torch.view_as_real(z_select) 95 return z_select_real.sum() 96 97 z = torch.randn(10, 2, 2, dtype=torch.double, requires_grad=True) 98 gradcheck(func, [z]) 99 func(z).backward() 100 101 z1 = z.clone().detach().requires_grad_(True) 102 torch.select(z1, z1.dim() - 2, 0).sum().backward() 103 104 self.assertEqual(z.grad, z1.grad) 105 106 107if __name__ == "__main__": 108 run_tests() 109