# Owner(s): ["oncall: jit"]

from typing import List

import torch
from torch.testing._internal.common_utils import skipIfTorchDynamo
from torch.testing._internal.jit_utils import JitTestCase


@skipIfTorchDynamo()
class TestAutodiffJit(JitTestCase):
    def test_undefined_tensor_lists(self):
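        # Script and profile fn, then call the resulting backward node directly,
        # once with a defined grad_output and once with an undefined (None) one.
        # The undefined case should produce None (or zero) gradients for the
        # tensor-list inputs rather than erroring.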
        def fn(tensor_list: List[torch.Tensor], add_tensor):
            cat = torch.cat(tensor_list, dim=1)
            r = torch.sin(cat + add_tensor)
            return r

        fn_s = torch.jit.script(fn)

        a = torch.rand((3, 6), requires_grad=True)
        b = torch.rand((3, 10), requires_grad=True)
        x = [a, b]
        y = torch.rand((3, 16), requires_grad=True)

        ret = fn_s(x, y)
        ret.sum().backward()
        ret = fn_s(x, y)
        ret.sum().backward()

        ret = fn_s(x, y)
        s = ret.sum()

        # backward_fn expects 2 inputs: (grad_output, current_grad_r)
        # current_grad_r is provided because we need to add this contribution
        # to grad_r when we return it.
        backward_fn = s.grad_fn.next_functions[0][0]

        # check behavior with defined tensor
        grad_out = torch.rand((3, 16))
        grad_inputs = backward_fn(grad_out, None)

        # expect 3 tensors: grad_y, grad_a, grad_b
        self.assertEqual(3, len(grad_inputs))
        for x in grad_inputs:
            self.assertTrue(isinstance(x, torch.Tensor))

        # now test with undefined grad_out
        grad_inputs = backward_fn(None, None)

        # expect all of them to be None (or an all-zero tensor)
        self.assertEqual(3, len(grad_inputs))
        for x in grad_inputs:
            if x is not None:
                self.assertEqual(0, torch.max(torch.abs(x)).item())

    def test_requires_grad_outputs(self):
        # outputs should require_grad only if eager outputs would require_grad.
        def fn(a, b, c):
            return a.relu() + b.relu(), c.relu()

        a = torch.rand((10, 10), requires_grad=False)
        b = torch.rand((10, 10), requires_grad=False)
        c = torch.rand((10, 10), requires_grad=True)

        fn_s = torch.jit.script(fn)

        for i in range(4):
            x, y = fn_s(a, b, c)
            self.assertFalse(x.requires_grad)
            self.assertTrue(y.requires_grad)

    def test_requires_grad_outputs_profiled_twice(self):
        # the value "r" is used twice, by gammaln and by entr, so it is profiled twice.
        # During autodiff graph formation the profile nodes are therefore not merged,
        # since they alias each other, and the DifferentiableGraph ends up with no
        # profile node on the output. The requires_grad info should then be added onto
        # the output value (otherwise autodiff will make the output require_grad).
        # Note: this relies on gammaln and entr not having autodiff implementations.
        def fn(a, b, c):
            r = a.relu().relu()
            return torch.special.gammaln(r), torch.special.entr(r), c.cos().relu()

        fn_s = torch.jit.script(fn)

        a = torch.rand((10, 10), requires_grad=False)
        b = torch.rand((10, 10), requires_grad=False)
        c = torch.rand((10, 10), requires_grad=True)

        for i in range(4):
            x_s, y_s, z_s = fn_s(a, b, c)
            x, y, z = fn(a, b, c)

            self.assertEqual(x_s.requires_grad, x.requires_grad)
            self.assertEqual(y_s.requires_grad, y.requires_grad)
            self.assertEqual(z_s.requires_grad, z.requires_grad)

    def test_requires_grad_outputs_side_effects(self):
        # same as above, but also add a CallFunction in between.
        @torch.jit.ignore
        def python_fn(x):
            return x.relu()

        def fn(a, b, c):
            r = a.relu().relu()
            z = python_fn(r)
            return torch.relu(r), torch.nn.functional.gelu(r), c.cos().relu()

        fn_s = torch.jit.script(fn)

        a = torch.rand((10, 10), requires_grad=False)
        b = torch.rand((10, 10), requires_grad=False)
        c = torch.rand((10, 10), requires_grad=True)

        for i in range(4):
            x_s, y_s, z_s = fn_s(a, b, c)
            x, y, z = fn(a, b, c)

            self.assertEqual(x_s.requires_grad, x.requires_grad)
            self.assertEqual(y_s.requires_grad, y.requires_grad)
            self.assertEqual(z_s.requires_grad, z.requires_grad)

    def test_autodiff_requires_grad_nograd(self):
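        # z is computed under torch.no_grad(), so it should not require grad;
        # x comes from differentiable ops on a grad-requiring input and y from an
        # ignored Python call. The scripted outputs' requires_grad should match eager.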
        @torch.jit.ignore
        def python_fn(x):
            return x.relu()

        def fn(a, b, c):
            x = a.sin().relu()
            y = python_fn(b)
            with torch.no_grad():
                z = x + c
            return x, y, z

        fn_s = torch.jit.script(fn)

        a = torch.rand((10, 10), requires_grad=True)
        b = torch.rand((10, 10), requires_grad=True)
        c = torch.rand((10, 10), requires_grad=True)

        for i in range(4):
            x_s, y_s, z_s = fn_s(a, b, c)
            x, y, z = fn(a, b, c)

            self.assertEqual(x_s.requires_grad, x.requires_grad)
            self.assertEqual(y_s.requires_grad, y.requires_grad)
            self.assertEqual(z_s.requires_grad, z.requires_grad)