# Owner(s): ["oncall: jit"]

from typing import List

import torch
from torch.testing._internal.common_utils import skipIfTorchDynamo
from torch.testing._internal.jit_utils import JitTestCase


@skipIfTorchDynamo()
class TestAutodiffJit(JitTestCase):
    def test_undefined_tensor_lists(self):
        def fn(tensor_list: List[torch.Tensor], add_tensor):
            cat = torch.cat(tensor_list, dim=1)
            r = torch.sin(cat + add_tensor)
            return r

        fn_s = torch.jit.script(fn)

        a = torch.rand((3, 6), requires_grad=True)
        b = torch.rand((3, 10), requires_grad=True)
        x = [a, b]
        y = torch.rand((3, 16), requires_grad=True)

        # warm up the profiling executor so the optimized plan is in place
        ret = fn_s(x, y)
        ret.sum().backward()
        ret = fn_s(x, y)
        ret.sum().backward()

        ret = fn_s(x, y)
        s = ret.sum()

        # backward_fn expects 2 inputs: (grad_output, current_grad_r)
        # current_grad_r is provided because we need to add this contribution
        # to grad_r when we return it.
        backward_fn = s.grad_fn.next_functions[0][0]

        # check behavior with a defined grad_output tensor
        grad_out = torch.rand((3, 16))
        grad_inputs = backward_fn(grad_out, None)

        # expect 3 tensors: grad_y, grad_a, grad_b
        self.assertEqual(3, len(grad_inputs))
        for g in grad_inputs:
            self.assertTrue(isinstance(g, torch.Tensor))

        # now test with an undefined grad_out
        grad_inputs = backward_fn(None, None)

        # expect each gradient to be undefined (None), or all-zero if defined
        self.assertEqual(3, len(grad_inputs))
        for g in grad_inputs:
            if g is not None:
                self.assertEqual(0, torch.max(torch.abs(g)).item())

    def test_requires_grad_outputs(self):
        # outputs should require_grad only if eager outputs would require_grad.
        def fn(a, b, c):
            return a.relu() + b.relu(), c.relu()

        a = torch.rand((10, 10), requires_grad=False)
        b = torch.rand((10, 10), requires_grad=False)
        c = torch.rand((10, 10), requires_grad=True)

        fn_s = torch.jit.script(fn)

        # run repeatedly so the profiling executor reaches its optimized plan
        for i in range(4):
            x, y = fn_s(a, b, c)
            self.assertFalse(x.requires_grad)
            self.assertTrue(y.requires_grad)

    def test_requires_grad_outputs_profiled_twice(self):
        # the value "r" is used twice, by gammaln and by entr, so it is
        # profiled twice. During autodiff graph formation the profile nodes
        # are therefore unmerged, because they alias, and the resulting
        # DifferentiableGraph has no profile node on the output. The
        # requires_grad info should then be added onto the output value
        # (otherwise autodiff will make the output require_grad).
        # Note: this relies on gammaln and entr not having autodiff
        # implementations.
        def fn(a, b, c):
            r = a.relu().relu()
            return torch.special.gammaln(r), torch.special.entr(r), c.cos().relu()

        fn_s = torch.jit.script(fn)

        a = torch.rand((10, 10), requires_grad=False)
        b = torch.rand((10, 10), requires_grad=False)
        c = torch.rand((10, 10), requires_grad=True)

        for i in range(4):
            x_s, y_s, z_s = fn_s(a, b, c)
            x, y, z = fn(a, b, c)

            self.assertEqual(x_s.requires_grad, x.requires_grad)
            self.assertEqual(y_s.requires_grad, y.requires_grad)
            self.assertEqual(z_s.requires_grad, z.requires_grad)

    def test_requires_grad_outputs_side_effects(self):
        # same as above, but also add a CallFunction in between.
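        # python_fn is kept opaque via @torch.jit.ignore, so the scripted
        # graph calls back into Python through a CallFunction node sitting
        # between the differentiable ops; its result is intentionally unused,
        # since the call itself is the side effect under test.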
        @torch.jit.ignore
        def python_fn(x):
            return x.relu()

        def fn(a, b, c):
            r = a.relu().relu()
            z = python_fn(r)
            return torch.relu(r), torch.nn.functional.gelu(r), c.cos().relu()

        fn_s = torch.jit.script(fn)

        a = torch.rand((10, 10), requires_grad=False)
        b = torch.rand((10, 10), requires_grad=False)
        c = torch.rand((10, 10), requires_grad=True)

        for i in range(4):
            x_s, y_s, z_s = fn_s(a, b, c)
            x, y, z = fn(a, b, c)

            self.assertEqual(x_s.requires_grad, x.requires_grad)
            self.assertEqual(y_s.requires_grad, y.requires_grad)
            self.assertEqual(z_s.requires_grad, z.requires_grad)

    def test_autodiff_requires_grad_nograd(self):
        @torch.jit.ignore
        def python_fn(x):
            return x.relu()

        def fn(a, b, c):
            x = a.sin().relu()
            y = python_fn(b)
            with torch.no_grad():
                z = x + c
            return x, y, z

        fn_s = torch.jit.script(fn)

        a = torch.rand((10, 10), requires_grad=True)
        b = torch.rand((10, 10), requires_grad=True)
        c = torch.rand((10, 10), requires_grad=True)

        for i in range(4):
            x_s, y_s, z_s = fn_s(a, b, c)
            x, y, z = fn(a, b, c)

            self.assertEqual(x_s.requires_grad, x.requires_grad)
            self.assertEqual(y_s.requires_grad, y.requires_grad)
            self.assertEqual(z_s.requires_grad, z.requires_grad)
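

# A minimal standalone sketch (not one of the tests above) of the warm-up
# pattern the `for i in range(4)` loops rely on: the profiling executor runs a
# scripted function a few times before switching to its optimized plan, which
# can then be inspected with torch.jit.last_executed_optimized_graph().
# `demo_fn` and its input are illustrative assumptions, not part of the suite.
if __name__ == "__main__":

    def demo_fn(a):
        return a.relu().sin()

    demo_fn_s = torch.jit.script(demo_fn)
    inp = torch.rand((10, 10), requires_grad=True)
    for _ in range(4):  # profiling runs; later calls use the optimized graph
        demo_fn_s(inp)
    # after warm-up, the optimized graph (which may contain a
    # DifferentiableGraph node) is available for inspection
    print(torch.jit.last_executed_optimized_graph())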