xref: /aosp_15_r20/external/pytorch/test/dynamo/test_hooks.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport contextlib
4*da0073e9SAndroid Build Coastguard Workerimport functools
5*da0073e9SAndroid Build Coastguard Workerimport unittest
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerimport torch
8*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo
9*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case
10*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing
11*da0073e9SAndroid Build Coastguard Workerfrom functorch.compile import nop
12*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo import compiled_autograd
13*da0073e9SAndroid Build Coastguard Workerfrom torch._functorch.aot_autograd import aot_module_simplified
14*da0073e9SAndroid Build Coastguard Workerfrom torch.utils.hooks import RemovableHandle
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Worker
17*da0073e9SAndroid Build Coastguard Workerdef compiler_fn(gm):
18*da0073e9SAndroid Build Coastguard Worker    return torch._dynamo.optimize("inductor", nopython=True, dynamic=True)(gm)
19*da0073e9SAndroid Build Coastguard Worker
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Workerdef global_hook_0(grad):
22*da0073e9SAndroid Build Coastguard Worker    return grad * 4
23*da0073e9SAndroid Build Coastguard Worker
24*da0073e9SAndroid Build Coastguard Worker
25*da0073e9SAndroid Build Coastguard Workerdef global_hook_1(grad):
26*da0073e9SAndroid Build Coastguard Worker    return grad / 2
27*da0073e9SAndroid Build Coastguard Worker
28*da0073e9SAndroid Build Coastguard Worker
29*da0073e9SAndroid Build Coastguard Workerdef global_hook_2(grad):
30*da0073e9SAndroid Build Coastguard Worker    return grad * 3
31*da0073e9SAndroid Build Coastguard Worker
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Workerh0 = None
34*da0073e9SAndroid Build Coastguard Worker
35*da0073e9SAndroid Build Coastguard Worker
36*da0073e9SAndroid Build Coastguard Workerclass ClassWithVal:
37*da0073e9SAndroid Build Coastguard Worker    def __init__(self, val):
38*da0073e9SAndroid Build Coastguard Worker        self.val = val
39*da0073e9SAndroid Build Coastguard Worker
40*da0073e9SAndroid Build Coastguard Worker
41*da0073e9SAndroid Build Coastguard Workerclass HooksTests(torch._dynamo.test_case.TestCase):
42*da0073e9SAndroid Build Coastguard Worker    def test_tensor_only_register_hook_in_graph_lambda(self):
43*da0073e9SAndroid Build Coastguard Worker        def fn(x):
44*da0073e9SAndroid Build Coastguard Worker            x.register_hook(lambda grad: grad * 2)
45*da0073e9SAndroid Build Coastguard Worker            return x
46*da0073e9SAndroid Build Coastguard Worker
47*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
48*da0073e9SAndroid Build Coastguard Worker        fn = torch._dynamo.optimize(cnts)(fn)
49*da0073e9SAndroid Build Coastguard Worker        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
50*da0073e9SAndroid Build Coastguard Worker        v = fn(v)
51*da0073e9SAndroid Build Coastguard Worker        v.backward(torch.tensor([1.0, 2.0, 3.0]))
52*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
53*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 0)
54*da0073e9SAndroid Build Coastguard Worker
55*da0073e9SAndroid Build Coastguard Worker    def test_tensor_register_hook_in_graph_lambda(self):
56*da0073e9SAndroid Build Coastguard Worker        def fn(x, y, z):
57*da0073e9SAndroid Build Coastguard Worker            x.register_hook(lambda grad: grad * 2)
58*da0073e9SAndroid Build Coastguard Worker            return x, y * y, z * z
59*da0073e9SAndroid Build Coastguard Worker
60*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
61*da0073e9SAndroid Build Coastguard Worker        fn = torch._dynamo.optimize(cnts)(fn)
62*da0073e9SAndroid Build Coastguard Worker        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
63*da0073e9SAndroid Build Coastguard Worker        v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
64*da0073e9SAndroid Build Coastguard Worker        v.backward(torch.tensor([1.0, 2.0, 3.0]))
65*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
66*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
67*da0073e9SAndroid Build Coastguard Worker
68*da0073e9SAndroid Build Coastguard Worker    def test_tensor_register_hook_in_graph_break_handle_lambda(self):
69*da0073e9SAndroid Build Coastguard Worker        def fn(x, y, z):
70*da0073e9SAndroid Build Coastguard Worker            handle = x.register_hook(lambda grad: grad * 2)
71*da0073e9SAndroid Build Coastguard Worker            z = z * z
72*da0073e9SAndroid Build Coastguard Worker            handle.remove()
73*da0073e9SAndroid Build Coastguard Worker            x.register_hook(lambda grad: grad * 3)
74*da0073e9SAndroid Build Coastguard Worker            return x, y * y, z
75*da0073e9SAndroid Build Coastguard Worker
76*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
77*da0073e9SAndroid Build Coastguard Worker        fn = torch._dynamo.optimize(cnts)(fn)
78*da0073e9SAndroid Build Coastguard Worker        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
79*da0073e9SAndroid Build Coastguard Worker        v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
80*da0073e9SAndroid Build Coastguard Worker        v.backward(torch.tensor([1.0, 2.0, 3.0]))
81*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v.grad, torch.tensor([3.0, 6.0, 9.0]))
82*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
83*da0073e9SAndroid Build Coastguard Worker
84*da0073e9SAndroid Build Coastguard Worker    def test_tensor_register_hook_multi_handle_return(self):
85*da0073e9SAndroid Build Coastguard Worker        def fn(x, y, z):
86*da0073e9SAndroid Build Coastguard Worker            handle = x.register_hook(lambda grad: grad * 2)
87*da0073e9SAndroid Build Coastguard Worker            h2 = handle
88*da0073e9SAndroid Build Coastguard Worker            z = z * z
89*da0073e9SAndroid Build Coastguard Worker            return x, y * y, z, handle, h2
90*da0073e9SAndroid Build Coastguard Worker
91*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
92*da0073e9SAndroid Build Coastguard Worker        fn = torch._dynamo.optimize(cnts)(fn)
93*da0073e9SAndroid Build Coastguard Worker        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
94*da0073e9SAndroid Build Coastguard Worker        v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))
95*da0073e9SAndroid Build Coastguard Worker        v.backward(torch.tensor([1.0, 2.0, 3.0]))
96*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
97*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
98*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(h, None)
99*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(h2, None)
100*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(h2, h)
101*da0073e9SAndroid Build Coastguard Worker
102*da0073e9SAndroid Build Coastguard Worker    def test_tensor_register_hook_repeated_handle_return(self):
103*da0073e9SAndroid Build Coastguard Worker        def fn(x, y, z):
104*da0073e9SAndroid Build Coastguard Worker            handle = x.register_hook(lambda grad: grad * 2)
105*da0073e9SAndroid Build Coastguard Worker            h2 = handle
106*da0073e9SAndroid Build Coastguard Worker            z = z * z
107*da0073e9SAndroid Build Coastguard Worker            return x, y * y, z, handle, handle
108*da0073e9SAndroid Build Coastguard Worker
109*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
110*da0073e9SAndroid Build Coastguard Worker        fn = torch._dynamo.optimize(cnts)(fn)
111*da0073e9SAndroid Build Coastguard Worker        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
112*da0073e9SAndroid Build Coastguard Worker        v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))
113*da0073e9SAndroid Build Coastguard Worker        v.backward(torch.tensor([1.0, 2.0, 3.0]))
114*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
115*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
116*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(h, RemovableHandle)
117*da0073e9SAndroid Build Coastguard Worker        self.assertIs(h2, h)
118*da0073e9SAndroid Build Coastguard Worker
119*da0073e9SAndroid Build Coastguard Worker    def test_removed_handle_return(self):
120*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
121*da0073e9SAndroid Build Coastguard Worker
122*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend=cnt, fullgraph=True)
123*da0073e9SAndroid Build Coastguard Worker        def fn(x, y, z):
124*da0073e9SAndroid Build Coastguard Worker            handle = x.register_hook(lambda grad: grad * 2)
125*da0073e9SAndroid Build Coastguard Worker            z = z * z
126*da0073e9SAndroid Build Coastguard Worker            handle.remove()
127*da0073e9SAndroid Build Coastguard Worker            handle.remove()
128*da0073e9SAndroid Build Coastguard Worker            return x, y * y, z, handle, handle
129*da0073e9SAndroid Build Coastguard Worker
130*da0073e9SAndroid Build Coastguard Worker        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
131*da0073e9SAndroid Build Coastguard Worker        v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))
132*da0073e9SAndroid Build Coastguard Worker        v.backward(torch.tensor([1.0, 2.0, 3.0]))
133*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v.grad, torch.tensor([1.0, 2.0, 3.0]))
134*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnt.frame_count, 1)
135*da0073e9SAndroid Build Coastguard Worker        self.assertIsInstance(h, RemovableHandle)
136*da0073e9SAndroid Build Coastguard Worker        self.assertIs(h2, h)
137*da0073e9SAndroid Build Coastguard Worker
138*da0073e9SAndroid Build Coastguard Worker    def test_tensor_register_hook_repeated_handle_not_local(self):
139*da0073e9SAndroid Build Coastguard Worker        def fn(x, y, z, mod):
140*da0073e9SAndroid Build Coastguard Worker            mod.handle = x.register_hook(lambda grad: grad * 2)
141*da0073e9SAndroid Build Coastguard Worker            z = z * z
142*da0073e9SAndroid Build Coastguard Worker            return x, y * y, z
143*da0073e9SAndroid Build Coastguard Worker
144*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
145*da0073e9SAndroid Build Coastguard Worker        fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
146*da0073e9SAndroid Build Coastguard Worker        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
147*da0073e9SAndroid Build Coastguard Worker
148*da0073e9SAndroid Build Coastguard Worker        mod = torch.nn.Module()
149*da0073e9SAndroid Build Coastguard Worker        mod.handle = None
150*da0073e9SAndroid Build Coastguard Worker
151*da0073e9SAndroid Build Coastguard Worker        v, y, z = fn(v, torch.randn([2, 2]), torch.randn([2, 2]), mod)
152*da0073e9SAndroid Build Coastguard Worker        v.backward(torch.tensor([1.0, 2.0, 3.0]))
153*da0073e9SAndroid Build Coastguard Worker
154*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
155*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
156*da0073e9SAndroid Build Coastguard Worker
157*da0073e9SAndroid Build Coastguard Worker        self.assertNotEqual(mod.handle, None)
158*da0073e9SAndroid Build Coastguard Worker
159*da0073e9SAndroid Build Coastguard Worker    def test_tensor_only_register_hook_in_graph_local(self):
160*da0073e9SAndroid Build Coastguard Worker        def local_hook(grad):
161*da0073e9SAndroid Build Coastguard Worker            return grad * 2
162*da0073e9SAndroid Build Coastguard Worker
163*da0073e9SAndroid Build Coastguard Worker        def fn(x):
164*da0073e9SAndroid Build Coastguard Worker            x.register_hook(local_hook)
165*da0073e9SAndroid Build Coastguard Worker            return x
166*da0073e9SAndroid Build Coastguard Worker
167*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
168*da0073e9SAndroid Build Coastguard Worker        fn = torch._dynamo.optimize(cnts)(fn)
169*da0073e9SAndroid Build Coastguard Worker        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
170*da0073e9SAndroid Build Coastguard Worker        v = fn(v)
171*da0073e9SAndroid Build Coastguard Worker        v.backward(torch.tensor([1.0, 2.0, 3.0]))
172*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
173*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 0)
174*da0073e9SAndroid Build Coastguard Worker
175*da0073e9SAndroid Build Coastguard Worker    def test_tensor_only_register_hook_in_graph_local_inner(self):
176*da0073e9SAndroid Build Coastguard Worker        def fn(x):
177*da0073e9SAndroid Build Coastguard Worker            def local_hook(grad):
178*da0073e9SAndroid Build Coastguard Worker                return grad * 2
179*da0073e9SAndroid Build Coastguard Worker
180*da0073e9SAndroid Build Coastguard Worker            z = x * x
181*da0073e9SAndroid Build Coastguard Worker            x.register_hook(local_hook)
182*da0073e9SAndroid Build Coastguard Worker            z.register_hook(local_hook)
183*da0073e9SAndroid Build Coastguard Worker            return x, z
184*da0073e9SAndroid Build Coastguard Worker
185*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
186*da0073e9SAndroid Build Coastguard Worker        fn = torch._dynamo.optimize(cnts)(fn)
187*da0073e9SAndroid Build Coastguard Worker        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
188*da0073e9SAndroid Build Coastguard Worker        v = fn(v)
189*da0073e9SAndroid Build Coastguard Worker        v[0].backward(torch.tensor([1.0, 2.0, 3.0]))
190*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v[0].grad, torch.tensor([2.0, 4.0, 6.0]))
191*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
192*da0073e9SAndroid Build Coastguard Worker
193*da0073e9SAndroid Build Coastguard Worker    def test_tensor_register_hook_in_graph_local(self):
194*da0073e9SAndroid Build Coastguard Worker        def local_hook(grad):
195*da0073e9SAndroid Build Coastguard Worker            return grad * 2
196*da0073e9SAndroid Build Coastguard Worker
197*da0073e9SAndroid Build Coastguard Worker        def fn(x, y, z):
198*da0073e9SAndroid Build Coastguard Worker            x.register_hook(local_hook)
199*da0073e9SAndroid Build Coastguard Worker            return x, y * y, z * z
200*da0073e9SAndroid Build Coastguard Worker
201*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
202*da0073e9SAndroid Build Coastguard Worker        fn = torch._dynamo.optimize(cnts)(fn)
203*da0073e9SAndroid Build Coastguard Worker        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
204*da0073e9SAndroid Build Coastguard Worker        v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
205*da0073e9SAndroid Build Coastguard Worker        v.backward(torch.tensor([1.0, 2.0, 3.0]))
206*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
207*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
208*da0073e9SAndroid Build Coastguard Worker
209*da0073e9SAndroid Build Coastguard Worker    def test_tensor_register_hook_in_graph_break_handle_local(self):
210*da0073e9SAndroid Build Coastguard Worker        def local_hook(grad):
211*da0073e9SAndroid Build Coastguard Worker            return grad * 2
212*da0073e9SAndroid Build Coastguard Worker
213*da0073e9SAndroid Build Coastguard Worker        def local_hook2(grad):
214*da0073e9SAndroid Build Coastguard Worker            return grad * 3
215*da0073e9SAndroid Build Coastguard Worker
216*da0073e9SAndroid Build Coastguard Worker        def fn(x, y, z):
217*da0073e9SAndroid Build Coastguard Worker            handle = x.register_hook(local_hook)
218*da0073e9SAndroid Build Coastguard Worker            z = z * z
219*da0073e9SAndroid Build Coastguard Worker            handle.remove()
220*da0073e9SAndroid Build Coastguard Worker            x.register_hook(local_hook2)
221*da0073e9SAndroid Build Coastguard Worker            return x, y * y, z
222*da0073e9SAndroid Build Coastguard Worker
223*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
224*da0073e9SAndroid Build Coastguard Worker        fn = torch._dynamo.optimize(cnts)(fn)
225*da0073e9SAndroid Build Coastguard Worker        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
226*da0073e9SAndroid Build Coastguard Worker        v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
227*da0073e9SAndroid Build Coastguard Worker        v.backward(torch.tensor([1.0, 2.0, 3.0]))
228*da0073e9SAndroid Build Coastguard Worker
229*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v.grad, torch.tensor([3.0, 6.0, 9.0]))
230*da0073e9SAndroid Build Coastguard Worker
231*da0073e9SAndroid Build Coastguard Worker    def test_tensor_register_global_hook(self):
232*da0073e9SAndroid Build Coastguard Worker        def fn(x):
233*da0073e9SAndroid Build Coastguard Worker            x.register_hook(global_hook_0)
234*da0073e9SAndroid Build Coastguard Worker            return x, x * x
235*da0073e9SAndroid Build Coastguard Worker
236*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
237*da0073e9SAndroid Build Coastguard Worker        fn = torch._dynamo.optimize(cnts)(fn)
238*da0073e9SAndroid Build Coastguard Worker        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
239*da0073e9SAndroid Build Coastguard Worker        v = fn(v)[0]
240*da0073e9SAndroid Build Coastguard Worker        v.backward(torch.tensor([1.0, 2.0, 3.0]))
241*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v.grad, torch.tensor([4.0, 8.0, 12.0]))
242*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
243*da0073e9SAndroid Build Coastguard Worker
244*da0073e9SAndroid Build Coastguard Worker    def test_tensor_register_multiple_hooks(self):
245*da0073e9SAndroid Build Coastguard Worker        def fn(x):
246*da0073e9SAndroid Build Coastguard Worker            x.register_hook(global_hook_0)  # * 4
247*da0073e9SAndroid Build Coastguard Worker            x.register_hook(global_hook_1)  # / 2
248*da0073e9SAndroid Build Coastguard Worker            x.register_hook(global_hook_2)  # * 3
249*da0073e9SAndroid Build Coastguard Worker            return x, x * x
250*da0073e9SAndroid Build Coastguard Worker
251*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
252*da0073e9SAndroid Build Coastguard Worker        fn = torch._dynamo.optimize(cnts)(fn)
253*da0073e9SAndroid Build Coastguard Worker        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
254*da0073e9SAndroid Build Coastguard Worker        v = fn(v)[0]
255*da0073e9SAndroid Build Coastguard Worker        v.backward(torch.tensor([1.0, 2.0, 3.0]))
256*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v.grad, torch.tensor([6.0, 12.0, 18.0]))
257*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
258*da0073e9SAndroid Build Coastguard Worker
259*da0073e9SAndroid Build Coastguard Worker    def test_tensor_register_multiple_hooks_handles_in_list(self):
260*da0073e9SAndroid Build Coastguard Worker        def fn(x):
261*da0073e9SAndroid Build Coastguard Worker            h0 = x.register_hook(global_hook_0)  # * 4
262*da0073e9SAndroid Build Coastguard Worker            h1 = x.register_hook(global_hook_1)  # / 2
263*da0073e9SAndroid Build Coastguard Worker            h2 = x.register_hook(global_hook_2)  # * 3
264*da0073e9SAndroid Build Coastguard Worker            return x, x * x, h0, h1, h2
265*da0073e9SAndroid Build Coastguard Worker
266*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
267*da0073e9SAndroid Build Coastguard Worker        fn = torch._dynamo.optimize(cnts)(fn)
268*da0073e9SAndroid Build Coastguard Worker        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
269*da0073e9SAndroid Build Coastguard Worker        v, r, handle_0, handle_1, handle_2 = fn(v)
270*da0073e9SAndroid Build Coastguard Worker        v.backward(torch.tensor([1.0, 2.0, 3.0]))
271*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v.grad, torch.tensor([6.0, 12.0, 18.0]))
272*da0073e9SAndroid Build Coastguard Worker        handle_0.remove()
273*da0073e9SAndroid Build Coastguard Worker        handle_1.remove()
274*da0073e9SAndroid Build Coastguard Worker        handle_2.remove()
275*da0073e9SAndroid Build Coastguard Worker
276*da0073e9SAndroid Build Coastguard Worker        v.backward(torch.tensor([1.0, 2.0, 3.0]))
277*da0073e9SAndroid Build Coastguard Worker        # Handles gone, grad is just applied as is
278*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v.grad, torch.tensor([7.0, 14.0, 21.0]))
279*da0073e9SAndroid Build Coastguard Worker
280*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
281*da0073e9SAndroid Build Coastguard Worker
282*da0073e9SAndroid Build Coastguard Worker    def test_tensor_register_global_hooks_handles_in_list(self):
283*da0073e9SAndroid Build Coastguard Worker        def fn(x):
284*da0073e9SAndroid Build Coastguard Worker            global h0
285*da0073e9SAndroid Build Coastguard Worker            h0 = x.register_hook(global_hook_0)  # * 4
286*da0073e9SAndroid Build Coastguard Worker            return x, x * x
287*da0073e9SAndroid Build Coastguard Worker
288*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
289*da0073e9SAndroid Build Coastguard Worker        fn = torch._dynamo.optimize(cnts)(fn)
290*da0073e9SAndroid Build Coastguard Worker        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
291*da0073e9SAndroid Build Coastguard Worker        v, r = fn(v)
292*da0073e9SAndroid Build Coastguard Worker
293*da0073e9SAndroid Build Coastguard Worker        self.assertIsNotNone(h0)
294*da0073e9SAndroid Build Coastguard Worker        v.backward(torch.tensor([1.0, 2.0, 3.0]))
295*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v.grad, torch.tensor([4.0, 8.0, 12.0]))
296*da0073e9SAndroid Build Coastguard Worker        h0.remove()
297*da0073e9SAndroid Build Coastguard Worker
298*da0073e9SAndroid Build Coastguard Worker        v.backward(torch.tensor([1.0, 2.0, 3.0]))
299*da0073e9SAndroid Build Coastguard Worker        # Handles gone, grad is just applied as is
300*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(v.grad, torch.tensor([5.0, 10.0, 15.0]))
301*da0073e9SAndroid Build Coastguard Worker
302*da0073e9SAndroid Build Coastguard Worker        # NYI!
303*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 0)
304*da0073e9SAndroid Build Coastguard Worker
305*da0073e9SAndroid Build Coastguard Worker    def test_intermediary_hooks(self):
306*da0073e9SAndroid Build Coastguard Worker        # Graph breaks because compiled_autograd is not set
307*da0073e9SAndroid Build Coastguard Worker        def simple_hook(g):
308*da0073e9SAndroid Build Coastguard Worker            return g * 2
309*da0073e9SAndroid Build Coastguard Worker
310*da0073e9SAndroid Build Coastguard Worker        def f(x):
311*da0073e9SAndroid Build Coastguard Worker            y = x + 1
312*da0073e9SAndroid Build Coastguard Worker            y.register_hook(simple_hook)
313*da0073e9SAndroid Build Coastguard Worker            z = y + 1
314*da0073e9SAndroid Build Coastguard Worker            return z
315*da0073e9SAndroid Build Coastguard Worker
316*da0073e9SAndroid Build Coastguard Worker        out = torch.randn(1, requires_grad=True)
317*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
318*da0073e9SAndroid Build Coastguard Worker        fn = torch._dynamo.optimize(cnts, nopython=False)(f)
319*da0073e9SAndroid Build Coastguard Worker        res = fn(out)
320*da0073e9SAndroid Build Coastguard Worker        res.backward()
321*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(res, f(out))
322*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 2)
323*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(out.grad, torch.Tensor([2.0]))
324*da0073e9SAndroid Build Coastguard Worker
325*da0073e9SAndroid Build Coastguard Worker    def test_intermediary_hooks_same_on_aot_eager(self):
326*da0073e9SAndroid Build Coastguard Worker        def my_hook(grad, *, k=0):
327*da0073e9SAndroid Build Coastguard Worker            return grad + k
328*da0073e9SAndroid Build Coastguard Worker
329*da0073e9SAndroid Build Coastguard Worker        class MyMod(torch.nn.Module):
330*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
331*da0073e9SAndroid Build Coastguard Worker                y = x.mul(2)
332*da0073e9SAndroid Build Coastguard Worker                hook1 = functools.partial(my_hook, k=3)
333*da0073e9SAndroid Build Coastguard Worker                hook2 = functools.partial(my_hook, k=4)
334*da0073e9SAndroid Build Coastguard Worker                y.register_hook(hook1)
335*da0073e9SAndroid Build Coastguard Worker                y.register_hook(hook2)
336*da0073e9SAndroid Build Coastguard Worker                z = y.mul(3)
337*da0073e9SAndroid Build Coastguard Worker                return (z,)
338*da0073e9SAndroid Build Coastguard Worker
339*da0073e9SAndroid Build Coastguard Worker        mod = MyMod()
340*da0073e9SAndroid Build Coastguard Worker        x0 = torch.ones(4, requires_grad=True)
341*da0073e9SAndroid Build Coastguard Worker        eager_out = mod(x0)
342*da0073e9SAndroid Build Coastguard Worker        eager_out[0].backward(torch.ones(4))
343*da0073e9SAndroid Build Coastguard Worker
344*da0073e9SAndroid Build Coastguard Worker        x1 = torch.ones(4, requires_grad=True)
345*da0073e9SAndroid Build Coastguard Worker        mod_compiled = aot_module_simplified(mod, (x1,), nop)
346*da0073e9SAndroid Build Coastguard Worker        aot_out = mod_compiled(x1)
347*da0073e9SAndroid Build Coastguard Worker        aot_out[0].backward(torch.ones(4))
348*da0073e9SAndroid Build Coastguard Worker
349*da0073e9SAndroid Build Coastguard Worker        x2 = torch.ones(4, requires_grad=True)
350*da0073e9SAndroid Build Coastguard Worker        with compiled_autograd.enable(compiler_fn):
351*da0073e9SAndroid Build Coastguard Worker            dynamo_out = torch._dynamo.optimize("aot_eager", nopython=True)(mod)(x2)
352*da0073e9SAndroid Build Coastguard Worker            dynamo_out[0].backward(torch.ones(4))
353*da0073e9SAndroid Build Coastguard Worker
354*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dynamo_out, aot_out)
355*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dynamo_out, eager_out)
356*da0073e9SAndroid Build Coastguard Worker
357*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x0.grad, x1.grad)
358*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x0.grad, x2.grad)
359*da0073e9SAndroid Build Coastguard Worker
360*da0073e9SAndroid Build Coastguard Worker    def test_input_hooks_same(self):
361*da0073e9SAndroid Build Coastguard Worker        backends = ["eager", "aot_eager", "inductor"]
362*da0073e9SAndroid Build Coastguard Worker        for backend in backends:
363*da0073e9SAndroid Build Coastguard Worker
364*da0073e9SAndroid Build Coastguard Worker            def my_hook(grad, *, k=0):
365*da0073e9SAndroid Build Coastguard Worker                return grad + k
366*da0073e9SAndroid Build Coastguard Worker
367*da0073e9SAndroid Build Coastguard Worker            hook = functools.partial(my_hook, k=3)
368*da0073e9SAndroid Build Coastguard Worker
369*da0073e9SAndroid Build Coastguard Worker            class MyMod(torch.nn.Module):
370*da0073e9SAndroid Build Coastguard Worker                def forward(self, x):
371*da0073e9SAndroid Build Coastguard Worker                    x.register_hook(hook)
372*da0073e9SAndroid Build Coastguard Worker                    y = x.mul(2)
373*da0073e9SAndroid Build Coastguard Worker                    z = y.mul(3)
374*da0073e9SAndroid Build Coastguard Worker                    return (z,)
375*da0073e9SAndroid Build Coastguard Worker
376*da0073e9SAndroid Build Coastguard Worker            mod = MyMod()
377*da0073e9SAndroid Build Coastguard Worker            x0 = torch.ones(4, requires_grad=True)
378*da0073e9SAndroid Build Coastguard Worker            eager_out = mod(x0)
379*da0073e9SAndroid Build Coastguard Worker            eager_out[0].backward(torch.ones(4))
380*da0073e9SAndroid Build Coastguard Worker
381*da0073e9SAndroid Build Coastguard Worker            x1 = torch.ones(4, requires_grad=True)
382*da0073e9SAndroid Build Coastguard Worker            mod_compiled = aot_module_simplified(mod, (x1,), nop)
383*da0073e9SAndroid Build Coastguard Worker            aot_out = mod_compiled(x1)
384*da0073e9SAndroid Build Coastguard Worker            aot_out[0].backward(torch.ones(4))
385*da0073e9SAndroid Build Coastguard Worker
386*da0073e9SAndroid Build Coastguard Worker            x2 = torch.ones(4, requires_grad=True)
387*da0073e9SAndroid Build Coastguard Worker            dynamo_out = torch._dynamo.optimize(backend, nopython=True)(mod)(x2)
388*da0073e9SAndroid Build Coastguard Worker            with compiled_autograd.enable(compiler_fn):
389*da0073e9SAndroid Build Coastguard Worker                dynamo_out[0].backward(torch.ones(4))
390*da0073e9SAndroid Build Coastguard Worker
391*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(dynamo_out, aot_out)
392*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(dynamo_out, eager_out)
393*da0073e9SAndroid Build Coastguard Worker
394*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x0.grad, x1.grad)
395*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x0.grad, x2.grad)
396*da0073e9SAndroid Build Coastguard Worker
397*da0073e9SAndroid Build Coastguard Worker    def test_intermediary_hooks_same_on_inductor(self):
398*da0073e9SAndroid Build Coastguard Worker        def my_hook(grad, *, k=0):
399*da0073e9SAndroid Build Coastguard Worker            return grad + k
400*da0073e9SAndroid Build Coastguard Worker
401*da0073e9SAndroid Build Coastguard Worker        class MyMod(torch.nn.Module):
402*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
403*da0073e9SAndroid Build Coastguard Worker                y = x.mul(2)
404*da0073e9SAndroid Build Coastguard Worker                hook1 = functools.partial(my_hook, k=3)
405*da0073e9SAndroid Build Coastguard Worker                hook2 = functools.partial(my_hook, k=4)
406*da0073e9SAndroid Build Coastguard Worker                y.register_hook(hook1)
407*da0073e9SAndroid Build Coastguard Worker                y.register_hook(hook2)
408*da0073e9SAndroid Build Coastguard Worker                z = y.mul(3)
409*da0073e9SAndroid Build Coastguard Worker                return (z,)
410*da0073e9SAndroid Build Coastguard Worker
411*da0073e9SAndroid Build Coastguard Worker        mod = MyMod()
412*da0073e9SAndroid Build Coastguard Worker        x0 = torch.ones(4, requires_grad=True)
413*da0073e9SAndroid Build Coastguard Worker        eager_out = mod(x0)
414*da0073e9SAndroid Build Coastguard Worker        eager_out[0].backward(torch.ones(4))
415*da0073e9SAndroid Build Coastguard Worker
416*da0073e9SAndroid Build Coastguard Worker        x1 = torch.ones(4, requires_grad=True)
417*da0073e9SAndroid Build Coastguard Worker        mod_compiled = aot_module_simplified(mod, (x1,), nop)
418*da0073e9SAndroid Build Coastguard Worker        aot_out = mod_compiled(x1)
419*da0073e9SAndroid Build Coastguard Worker        aot_out[0].backward(torch.ones(4))
420*da0073e9SAndroid Build Coastguard Worker
421*da0073e9SAndroid Build Coastguard Worker        x2 = torch.ones(4, requires_grad=True)
422*da0073e9SAndroid Build Coastguard Worker        with compiled_autograd.enable(compiler_fn):
423*da0073e9SAndroid Build Coastguard Worker            dynamo_out = torch._dynamo.optimize("inductor", nopython=True)(mod)(x2)
424*da0073e9SAndroid Build Coastguard Worker            dynamo_out[0].backward(torch.ones(4))
425*da0073e9SAndroid Build Coastguard Worker
426*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dynamo_out, aot_out)
427*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dynamo_out, eager_out)
428*da0073e9SAndroid Build Coastguard Worker
429*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x0.grad, x1.grad)
430*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x0.grad, x2.grad)
431*da0073e9SAndroid Build Coastguard Worker
432*da0073e9SAndroid Build Coastguard Worker    def test_complex_state_mutation_in_intermediary_hooks_same_on_inductor(self):
433*da0073e9SAndroid Build Coastguard Worker        class SomePyClass:
434*da0073e9SAndroid Build Coastguard Worker            count = 0
435*da0073e9SAndroid Build Coastguard Worker
436*da0073e9SAndroid Build Coastguard Worker            def do_stuff(self, grad):
437*da0073e9SAndroid Build Coastguard Worker                if self.count % 2 == 0:
438*da0073e9SAndroid Build Coastguard Worker                    r = grad * grad
439*da0073e9SAndroid Build Coastguard Worker                else:
440*da0073e9SAndroid Build Coastguard Worker                    r = grad + grad
441*da0073e9SAndroid Build Coastguard Worker                self.count += 1
442*da0073e9SAndroid Build Coastguard Worker                return r
443*da0073e9SAndroid Build Coastguard Worker
444*da0073e9SAndroid Build Coastguard Worker        def complex_state_touching_hook(grad, *, obj):
445*da0073e9SAndroid Build Coastguard Worker            return obj.do_stuff(grad)
446*da0073e9SAndroid Build Coastguard Worker
447*da0073e9SAndroid Build Coastguard Worker        class MyMod(torch.nn.Module):
448*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, obj):
449*da0073e9SAndroid Build Coastguard Worker                y = x.mul(2)
450*da0073e9SAndroid Build Coastguard Worker                hook1 = functools.partial(complex_state_touching_hook, obj=obj)
451*da0073e9SAndroid Build Coastguard Worker                hook2 = functools.partial(complex_state_touching_hook, obj=obj)
452*da0073e9SAndroid Build Coastguard Worker                y.register_hook(hook1)
453*da0073e9SAndroid Build Coastguard Worker                y.register_hook(hook2)
454*da0073e9SAndroid Build Coastguard Worker                z = y.mul(3)
455*da0073e9SAndroid Build Coastguard Worker                return (z,)
456*da0073e9SAndroid Build Coastguard Worker
457*da0073e9SAndroid Build Coastguard Worker        mod = MyMod()
458*da0073e9SAndroid Build Coastguard Worker        obj = SomePyClass()
459*da0073e9SAndroid Build Coastguard Worker        x0 = torch.ones(4, requires_grad=True)
460*da0073e9SAndroid Build Coastguard Worker        eager_out = mod(x0, obj)
461*da0073e9SAndroid Build Coastguard Worker        eager_out[0].backward(torch.ones(4))
462*da0073e9SAndroid Build Coastguard Worker
463*da0073e9SAndroid Build Coastguard Worker        # Eager 2
464*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(obj.count, 2)
465*da0073e9SAndroid Build Coastguard Worker        x2 = torch.ones(4, requires_grad=True)
466*da0073e9SAndroid Build Coastguard Worker        with compiled_autograd.enable(compiler_fn):
467*da0073e9SAndroid Build Coastguard Worker            dynamo_out = torch._dynamo.optimize("inductor", nopython=True)(mod)(x2, obj)
468*da0073e9SAndroid Build Coastguard Worker            dynamo_out[0].backward(torch.ones(4))
469*da0073e9SAndroid Build Coastguard Worker
470*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(dynamo_out, eager_out)
471*da0073e9SAndroid Build Coastguard Worker
472*da0073e9SAndroid Build Coastguard Worker        # Eager 2 + compiled 2
473*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(obj.count, 4)
474*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x0.grad, x2.grad)
475*da0073e9SAndroid Build Coastguard Worker
476*da0073e9SAndroid Build Coastguard Worker    def test_complex_state_mutation_in_intermediary_hooks_same_on_inductor_with_graph_break(
477*da0073e9SAndroid Build Coastguard Worker        self,
478*da0073e9SAndroid Build Coastguard Worker    ):
479*da0073e9SAndroid Build Coastguard Worker        class SomePyClass:
480*da0073e9SAndroid Build Coastguard Worker            grad_as_str = "None"
481*da0073e9SAndroid Build Coastguard Worker            count = 0
482*da0073e9SAndroid Build Coastguard Worker
483*da0073e9SAndroid Build Coastguard Worker            def write_grad_as_str_and_do_stuff(self, grad):
484*da0073e9SAndroid Build Coastguard Worker                self.grad_as_str = str(grad)
485*da0073e9SAndroid Build Coastguard Worker                if self.count % 2 == 0:
486*da0073e9SAndroid Build Coastguard Worker                    r = grad * grad
487*da0073e9SAndroid Build Coastguard Worker                else:
488*da0073e9SAndroid Build Coastguard Worker                    r = grad + grad
489*da0073e9SAndroid Build Coastguard Worker                print("Break!")
490*da0073e9SAndroid Build Coastguard Worker                self.count += 1
491*da0073e9SAndroid Build Coastguard Worker                return r
492*da0073e9SAndroid Build Coastguard Worker
493*da0073e9SAndroid Build Coastguard Worker        def complex_state_touching_hook(grad, *, obj):
494*da0073e9SAndroid Build Coastguard Worker            return obj.write_grad_as_str_and_do_stuff(grad)
495*da0073e9SAndroid Build Coastguard Worker
496*da0073e9SAndroid Build Coastguard Worker        class MyMod(torch.nn.Module):
497*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, obj):
498*da0073e9SAndroid Build Coastguard Worker                y = x.mul(2)
499*da0073e9SAndroid Build Coastguard Worker                hook1 = functools.partial(complex_state_touching_hook, obj=obj)
500*da0073e9SAndroid Build Coastguard Worker                hook2 = functools.partial(complex_state_touching_hook, obj=obj)
501*da0073e9SAndroid Build Coastguard Worker                y.register_hook(hook1)
502*da0073e9SAndroid Build Coastguard Worker                y.register_hook(hook2)
503*da0073e9SAndroid Build Coastguard Worker                z = y.mul(3)
504*da0073e9SAndroid Build Coastguard Worker                return (z,)
505*da0073e9SAndroid Build Coastguard Worker
506*da0073e9SAndroid Build Coastguard Worker        mod = MyMod()
507*da0073e9SAndroid Build Coastguard Worker        obj = SomePyClass()
508*da0073e9SAndroid Build Coastguard Worker        x0 = torch.ones(4, requires_grad=True)
509*da0073e9SAndroid Build Coastguard Worker        eager_out = mod(x0, obj)
510*da0073e9SAndroid Build Coastguard Worker        eager_out[0].backward(torch.ones(4))
511*da0073e9SAndroid Build Coastguard Worker
512*da0073e9SAndroid Build Coastguard Worker        x2 = torch.ones(4, requires_grad=True)
513*da0073e9SAndroid Build Coastguard Worker        with compiled_autograd.enable(compiler_fn):
514*da0073e9SAndroid Build Coastguard Worker            dynamo_out = torch._dynamo.optimize("inductor", nopython=True)(mod)(x2, obj)
515*da0073e9SAndroid Build Coastguard Worker            with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "builtin: str"):
516*da0073e9SAndroid Build Coastguard Worker                dynamo_out[0].backward(torch.ones(4))
517*da0073e9SAndroid Build Coastguard Worker
518*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(obj.count, 2)
519*da0073e9SAndroid Build Coastguard Worker
520*da0073e9SAndroid Build Coastguard Worker    def test_register_hook_partial_guarding(
521*da0073e9SAndroid Build Coastguard Worker        self,
522*da0073e9SAndroid Build Coastguard Worker    ):
523*da0073e9SAndroid Build Coastguard Worker        def some_hook(grad, *, obj):
524*da0073e9SAndroid Build Coastguard Worker            return grad + obj.val
525*da0073e9SAndroid Build Coastguard Worker
526*da0073e9SAndroid Build Coastguard Worker        class MyMod(torch.nn.Module):
527*da0073e9SAndroid Build Coastguard Worker            def forward(self, x, obj):
528*da0073e9SAndroid Build Coastguard Worker                y = x.mul(2)
529*da0073e9SAndroid Build Coastguard Worker                hook1 = functools.partial(some_hook, obj=obj)
530*da0073e9SAndroid Build Coastguard Worker                y.register_hook(hook1)
531*da0073e9SAndroid Build Coastguard Worker                z = y.mul(3)
532*da0073e9SAndroid Build Coastguard Worker                return (z,)
533*da0073e9SAndroid Build Coastguard Worker
534*da0073e9SAndroid Build Coastguard Worker        mod = MyMod()
535*da0073e9SAndroid Build Coastguard Worker        obj1 = ClassWithVal(torch.tensor(88))
536*da0073e9SAndroid Build Coastguard Worker        obj2 = ClassWithVal(torch.tensor(99))
537*da0073e9SAndroid Build Coastguard Worker        obj3 = ClassWithVal(11)
538*da0073e9SAndroid Build Coastguard Worker        cnt = torch._dynamo.testing.CompileCounter()
539*da0073e9SAndroid Build Coastguard Worker
540*da0073e9SAndroid Build Coastguard Worker        x0 = torch.ones(4, requires_grad=True)
541*da0073e9SAndroid Build Coastguard Worker        x1 = torch.ones(4, requires_grad=True)
542*da0073e9SAndroid Build Coastguard Worker
543*da0073e9SAndroid Build Coastguard Worker        with compiled_autograd.enable(compiler_fn):
544*da0073e9SAndroid Build Coastguard Worker            torch.compile(mod, backend=cnt, fullgraph=True)(x0, obj1)
545*da0073e9SAndroid Build Coastguard Worker            torch.compile(mod, backend=cnt, fullgraph=True)(x1, obj1)
546*da0073e9SAndroid Build Coastguard Worker            torch.compile(mod, backend=cnt, fullgraph=True)(x0, obj2)
547*da0073e9SAndroid Build Coastguard Worker            torch.compile(mod, backend=cnt, fullgraph=True)(x0, obj3)
548*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cnt.frame_count, 1)
549*da0073e9SAndroid Build Coastguard Worker
550*da0073e9SAndroid Build Coastguard Worker    def test_hook_with_closure(self):
551*da0073e9SAndroid Build Coastguard Worker        def fn(x, obj):
552*da0073e9SAndroid Build Coastguard Worker            y = x.sin()
553*da0073e9SAndroid Build Coastguard Worker            x.register_hook(lambda grad: grad + obj.val)
554*da0073e9SAndroid Build Coastguard Worker            z = y.sin()
555*da0073e9SAndroid Build Coastguard Worker            return z
556*da0073e9SAndroid Build Coastguard Worker
557*da0073e9SAndroid Build Coastguard Worker        cnt_fw = torch._dynamo.testing.CompileCounter()
558*da0073e9SAndroid Build Coastguard Worker        cnt_bw = torch._dynamo.testing.CompileCounter()
559*da0073e9SAndroid Build Coastguard Worker        opt = torch.compile(fn, backend=cnt_fw, fullgraph=True)
560*da0073e9SAndroid Build Coastguard Worker
561*da0073e9SAndroid Build Coastguard Worker        obj1 = ClassWithVal(torch.tensor(88))
562*da0073e9SAndroid Build Coastguard Worker        obj2 = ClassWithVal(torch.tensor(99))
563*da0073e9SAndroid Build Coastguard Worker        x0 = torch.ones(4, requires_grad=True)
564*da0073e9SAndroid Build Coastguard Worker        x1 = torch.ones(4, requires_grad=True)
565*da0073e9SAndroid Build Coastguard Worker        x2 = torch.ones(4, requires_grad=True)
566*da0073e9SAndroid Build Coastguard Worker        x3 = torch.ones(4, requires_grad=True)
567*da0073e9SAndroid Build Coastguard Worker        fn(x0, obj1).sum().backward()
568*da0073e9SAndroid Build Coastguard Worker        fn(x1, obj2).sum().backward()
569*da0073e9SAndroid Build Coastguard Worker
570*da0073e9SAndroid Build Coastguard Worker        with compiled_autograd.enable(
571*da0073e9SAndroid Build Coastguard Worker            functools.partial(torch.compile, backend=cnt_bw, fullgraph=True)
572*da0073e9SAndroid Build Coastguard Worker        ):
573*da0073e9SAndroid Build Coastguard Worker            opt(x2, obj1).sum().backward()
574*da0073e9SAndroid Build Coastguard Worker            opt(x3, obj2).sum().backward()
575*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cnt_fw.frame_count, 1)
576*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cnt_bw.frame_count, 1)
577*da0073e9SAndroid Build Coastguard Worker
578*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x0.grad, x2.grad)
579*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x1.grad, x3.grad)
580*da0073e9SAndroid Build Coastguard Worker
581*da0073e9SAndroid Build Coastguard Worker    def test_intermediate_hook_with_closure_eager(self):
582*da0073e9SAndroid Build Coastguard Worker        def fn(x, obj):
583*da0073e9SAndroid Build Coastguard Worker            y = x.sin()
584*da0073e9SAndroid Build Coastguard Worker            y.register_hook(lambda grad: grad + obj.val)
585*da0073e9SAndroid Build Coastguard Worker            z = y.sin()
586*da0073e9SAndroid Build Coastguard Worker            return z
587*da0073e9SAndroid Build Coastguard Worker
588*da0073e9SAndroid Build Coastguard Worker        cnt_fw = torch._dynamo.testing.CompileCounter()
589*da0073e9SAndroid Build Coastguard Worker        cnt_bw = torch._dynamo.testing.CompileCounter()
590*da0073e9SAndroid Build Coastguard Worker        opt = torch.compile(fn, backend=cnt_fw, fullgraph=True)
591*da0073e9SAndroid Build Coastguard Worker
592*da0073e9SAndroid Build Coastguard Worker        obj1 = ClassWithVal(torch.tensor(88))
593*da0073e9SAndroid Build Coastguard Worker        obj2 = ClassWithVal(torch.tensor(99))
594*da0073e9SAndroid Build Coastguard Worker        x0 = torch.ones(4, requires_grad=True)
595*da0073e9SAndroid Build Coastguard Worker        x1 = torch.ones(4, requires_grad=True)
596*da0073e9SAndroid Build Coastguard Worker        x2 = torch.ones(4, requires_grad=True)
597*da0073e9SAndroid Build Coastguard Worker        x3 = torch.ones(4, requires_grad=True)
598*da0073e9SAndroid Build Coastguard Worker        fn(x0, obj1).sum().backward()
599*da0073e9SAndroid Build Coastguard Worker        fn(x1, obj2).sum().backward()
600*da0073e9SAndroid Build Coastguard Worker
601*da0073e9SAndroid Build Coastguard Worker        with compiled_autograd.enable(
602*da0073e9SAndroid Build Coastguard Worker            functools.partial(torch.compile, backend=cnt_bw, fullgraph=True)
603*da0073e9SAndroid Build Coastguard Worker        ):
604*da0073e9SAndroid Build Coastguard Worker            opt(x2, obj1).sum().backward()
605*da0073e9SAndroid Build Coastguard Worker            opt(x3, obj2).sum().backward()
606*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cnt_fw.frame_count, 1)
607*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cnt_bw.frame_count, 1)
608*da0073e9SAndroid Build Coastguard Worker
609*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x0.grad, x2.grad)
610*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x1.grad, x3.grad)
611*da0073e9SAndroid Build Coastguard Worker
612*da0073e9SAndroid Build Coastguard Worker    def test_intermediate_hook_with_closure_aot(self):
613*da0073e9SAndroid Build Coastguard Worker        def fn(x, obj):
614*da0073e9SAndroid Build Coastguard Worker            y = x.sin()
615*da0073e9SAndroid Build Coastguard Worker            y.register_hook(lambda grad: grad + obj.val)
616*da0073e9SAndroid Build Coastguard Worker            z = y.sin()
617*da0073e9SAndroid Build Coastguard Worker            return z
618*da0073e9SAndroid Build Coastguard Worker
619*da0073e9SAndroid Build Coastguard Worker        cnt_bw = torch._dynamo.testing.CompileCounter()
620*da0073e9SAndroid Build Coastguard Worker        opt = torch.compile(fn, backend="aot_eager", fullgraph=True)
621*da0073e9SAndroid Build Coastguard Worker
622*da0073e9SAndroid Build Coastguard Worker        obj1 = ClassWithVal(torch.tensor(88))
623*da0073e9SAndroid Build Coastguard Worker        obj2 = ClassWithVal(torch.tensor(99))
624*da0073e9SAndroid Build Coastguard Worker        x0 = torch.ones(4, requires_grad=True)
625*da0073e9SAndroid Build Coastguard Worker        x1 = torch.ones(4, requires_grad=True)
626*da0073e9SAndroid Build Coastguard Worker        x2 = torch.ones(4, requires_grad=True)
627*da0073e9SAndroid Build Coastguard Worker        x3 = torch.ones(4, requires_grad=True)
628*da0073e9SAndroid Build Coastguard Worker        fn(x0, obj1).sum().backward()
629*da0073e9SAndroid Build Coastguard Worker        fn(x1, obj2).sum().backward()
630*da0073e9SAndroid Build Coastguard Worker
631*da0073e9SAndroid Build Coastguard Worker        with compiled_autograd.enable(
632*da0073e9SAndroid Build Coastguard Worker            functools.partial(torch.compile, backend=cnt_bw, fullgraph=True)
633*da0073e9SAndroid Build Coastguard Worker        ):
634*da0073e9SAndroid Build Coastguard Worker            opt(x2, obj1).sum().backward()
635*da0073e9SAndroid Build Coastguard Worker            opt(x3, obj2).sum().backward()
636*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cnt_bw.frame_count, 1)
637*da0073e9SAndroid Build Coastguard Worker
638*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x0.grad, x2.grad)
639*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(x1.grad, x3.grad)
640*da0073e9SAndroid Build Coastguard Worker
641*da0073e9SAndroid Build Coastguard Worker    def test_no_recompile_on_hook_identity_change(self):
642*da0073e9SAndroid Build Coastguard Worker        def my_hook(grad, k=0):
643*da0073e9SAndroid Build Coastguard Worker            return grad + k
644*da0073e9SAndroid Build Coastguard Worker
645*da0073e9SAndroid Build Coastguard Worker        def my_hook2(grad):
646*da0073e9SAndroid Build Coastguard Worker            return grad * 2
647*da0073e9SAndroid Build Coastguard Worker
648*da0073e9SAndroid Build Coastguard Worker        class MyMod(torch.nn.Module):
649*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
650*da0073e9SAndroid Build Coastguard Worker                y = x.mul(2)
651*da0073e9SAndroid Build Coastguard Worker                y.register_hook(my_hook)
652*da0073e9SAndroid Build Coastguard Worker                y.register_hook(my_hook)
653*da0073e9SAndroid Build Coastguard Worker                z = y.mul(3)
654*da0073e9SAndroid Build Coastguard Worker                return (z,)
655*da0073e9SAndroid Build Coastguard Worker
656*da0073e9SAndroid Build Coastguard Worker        mod = MyMod()
657*da0073e9SAndroid Build Coastguard Worker        x0 = torch.ones(4, requires_grad=True)
658*da0073e9SAndroid Build Coastguard Worker        eager_out = mod(x0)
659*da0073e9SAndroid Build Coastguard Worker        eager_out[0].backward(torch.ones(4))
660*da0073e9SAndroid Build Coastguard Worker
661*da0073e9SAndroid Build Coastguard Worker        x1 = torch.ones(4, requires_grad=True)
662*da0073e9SAndroid Build Coastguard Worker        with compiled_autograd.enable(compiler_fn):
663*da0073e9SAndroid Build Coastguard Worker            cnts = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
664*da0073e9SAndroid Build Coastguard Worker            comp_mod = torch._dynamo.optimize(cnts, nopython=True)(mod)
665*da0073e9SAndroid Build Coastguard Worker            comp_out = comp_mod(x1)
666*da0073e9SAndroid Build Coastguard Worker            comp_out[0].backward(torch.ones(4))
667*da0073e9SAndroid Build Coastguard Worker
668*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cnts.frame_count, 1)
669*da0073e9SAndroid Build Coastguard Worker            my_hook = my_hook2  # noqa: F811
670*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x0.grad, x1.grad)
671*da0073e9SAndroid Build Coastguard Worker
672*da0073e9SAndroid Build Coastguard Worker            eager_out = mod(x0)
673*da0073e9SAndroid Build Coastguard Worker            eager_out[0].backward(torch.ones(4))
674*da0073e9SAndroid Build Coastguard Worker
675*da0073e9SAndroid Build Coastguard Worker            comp_out = comp_mod(x1)
676*da0073e9SAndroid Build Coastguard Worker
677*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(cnts.frame_count, 1)
678*da0073e9SAndroid Build Coastguard Worker            comp_out[0].backward(torch.ones(4))
679*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x0.grad, x1.grad)
680*da0073e9SAndroid Build Coastguard Worker
681*da0073e9SAndroid Build Coastguard Worker    def test_functools_arg_vary(self):
682*da0073e9SAndroid Build Coastguard Worker        def pre_hook(grad, *, k):
683*da0073e9SAndroid Build Coastguard Worker            return grad * k
684*da0073e9SAndroid Build Coastguard Worker
685*da0073e9SAndroid Build Coastguard Worker        hook = functools.partial(pre_hook, k=1)
686*da0073e9SAndroid Build Coastguard Worker
687*da0073e9SAndroid Build Coastguard Worker        @torch.compile(backend="eager", fullgraph=True)
688*da0073e9SAndroid Build Coastguard Worker        def h(x):
689*da0073e9SAndroid Build Coastguard Worker            y = x.mul(2)
690*da0073e9SAndroid Build Coastguard Worker            y.register_hook(hook)
691*da0073e9SAndroid Build Coastguard Worker            return y.mul(3)
692*da0073e9SAndroid Build Coastguard Worker
693*da0073e9SAndroid Build Coastguard Worker        with compiled_autograd.enable(torch.compile(backend="eager", fullgraph=True)):
694*da0073e9SAndroid Build Coastguard Worker            x = torch.randn(2, requires_grad=True)
695*da0073e9SAndroid Build Coastguard Worker            h(x).sum().backward()
696*da0073e9SAndroid Build Coastguard Worker            orig_grad = x.grad
697*da0073e9SAndroid Build Coastguard Worker            x.grad = None
698*da0073e9SAndroid Build Coastguard Worker
699*da0073e9SAndroid Build Coastguard Worker            hook = functools.partial(pre_hook, k=2)
700*da0073e9SAndroid Build Coastguard Worker            h(x).sum().backward()
701*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(orig_grad * 2, x.grad)
702*da0073e9SAndroid Build Coastguard Worker
703*da0073e9SAndroid Build Coastguard Worker    def test_post_acc_grad_hook(self):
704*da0073e9SAndroid Build Coastguard Worker        def hook(input_t):
705*da0073e9SAndroid Build Coastguard Worker            input_t.mul_(input_t.grad)
706*da0073e9SAndroid Build Coastguard Worker            input_t.grad.mul_(5)
707*da0073e9SAndroid Build Coastguard Worker
708*da0073e9SAndroid Build Coastguard Worker        def reg_and_mul(x, y):
709*da0073e9SAndroid Build Coastguard Worker            x.register_post_accumulate_grad_hook(hook)
710*da0073e9SAndroid Build Coastguard Worker            return x * y
711*da0073e9SAndroid Build Coastguard Worker
712*da0073e9SAndroid Build Coastguard Worker        cnts = None
713*da0073e9SAndroid Build Coastguard Worker
714*da0073e9SAndroid Build Coastguard Worker        def test_fn(fn):
715*da0073e9SAndroid Build Coastguard Worker            fn(x, y)
716*da0073e9SAndroid Build Coastguard Worker            b = torch.tensor([2.0, 2.0, 2.0], requires_grad=True)
717*da0073e9SAndroid Build Coastguard Worker            x.backward(b)
718*da0073e9SAndroid Build Coastguard Worker            if cnts:
719*da0073e9SAndroid Build Coastguard Worker                self.assertEqual(cnts.frame_count, 1)
720*da0073e9SAndroid Build Coastguard Worker            # These same exact assertions run on both eager and compiled
721*da0073e9SAndroid Build Coastguard Worker            # X goes to x*2 becaue of mul_
722*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x, torch.tensor([0.5, 0.5, 0.5]) * 2)
723*da0073e9SAndroid Build Coastguard Worker            # This test proves grad aliasing works -
724*da0073e9SAndroid Build Coastguard Worker            self.assertEqual(x.grad, b * 5)
725*da0073e9SAndroid Build Coastguard Worker
726*da0073e9SAndroid Build Coastguard Worker        # Eager values
727*da0073e9SAndroid Build Coastguard Worker        x = torch.tensor([0.5, 0.5, 0.5], requires_grad=True)
728*da0073e9SAndroid Build Coastguard Worker        y = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
729*da0073e9SAndroid Build Coastguard Worker        test_fn(reg_and_mul)
730*da0073e9SAndroid Build Coastguard Worker
731*da0073e9SAndroid Build Coastguard Worker        # Compiled
732*da0073e9SAndroid Build Coastguard Worker        for backend in ["eager", "aot_eager", "inductor"]:
733*da0073e9SAndroid Build Coastguard Worker            for compiled_bwd in [False, True]:
734*da0073e9SAndroid Build Coastguard Worker                torch._dynamo.reset()
735*da0073e9SAndroid Build Coastguard Worker                x = torch.tensor([0.5, 0.5, 0.5], requires_grad=True)
736*da0073e9SAndroid Build Coastguard Worker                y = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
737*da0073e9SAndroid Build Coastguard Worker
738*da0073e9SAndroid Build Coastguard Worker                cnts = torch._dynamo.testing.CompileCounterWithBackend(backend)
739*da0073e9SAndroid Build Coastguard Worker                compiled_fn = torch._dynamo.optimize(cnts, nopython=True)(reg_and_mul)
740*da0073e9SAndroid Build Coastguard Worker
741*da0073e9SAndroid Build Coastguard Worker                compiled_bwd_ctx = (
742*da0073e9SAndroid Build Coastguard Worker                    compiled_autograd.enable(
743*da0073e9SAndroid Build Coastguard Worker                        torch.compile(backend=backend, fullgraph=True)
744*da0073e9SAndroid Build Coastguard Worker                    )
745*da0073e9SAndroid Build Coastguard Worker                    if compiled_bwd
746*da0073e9SAndroid Build Coastguard Worker                    else contextlib.nullcontext()
747*da0073e9SAndroid Build Coastguard Worker                )
748*da0073e9SAndroid Build Coastguard Worker                with compiled_bwd_ctx:
749*da0073e9SAndroid Build Coastguard Worker                    test_fn(compiled_fn)
750*da0073e9SAndroid Build Coastguard Worker
751*da0073e9SAndroid Build Coastguard Worker    def test_recompile(self):
752*da0073e9SAndroid Build Coastguard Worker        def hook(param):
753*da0073e9SAndroid Build Coastguard Worker            param.grad *= 2
754*da0073e9SAndroid Build Coastguard Worker
755*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(10)
756*da0073e9SAndroid Build Coastguard Worker        x.requires_grad = True
757*da0073e9SAndroid Build Coastguard Worker
758*da0073e9SAndroid Build Coastguard Worker        def run(input):
759*da0073e9SAndroid Build Coastguard Worker            return x * input
760*da0073e9SAndroid Build Coastguard Worker
761*da0073e9SAndroid Build Coastguard Worker        x.register_post_accumulate_grad_hook(hook)
762*da0073e9SAndroid Build Coastguard Worker        with compiled_autograd.enable(compiler_fn):
763*da0073e9SAndroid Build Coastguard Worker            for i in range(5):
764*da0073e9SAndroid Build Coastguard Worker                with unittest.mock.patch(
765*da0073e9SAndroid Build Coastguard Worker                    "torch._dynamo.config.error_on_recompile", True
766*da0073e9SAndroid Build Coastguard Worker                ):
767*da0073e9SAndroid Build Coastguard Worker                    # Mimic optimizer.zero_grad() to clear the gradient
768*da0073e9SAndroid Build Coastguard Worker                    x.grad = None
769*da0073e9SAndroid Build Coastguard Worker                    run(i).sum().backward()
770*da0073e9SAndroid Build Coastguard Worker
771*da0073e9SAndroid Build Coastguard Worker    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
772*da0073e9SAndroid Build Coastguard Worker    def test_no_recompile_on_same_hook(self):
773*da0073e9SAndroid Build Coastguard Worker        cnts = torch._dynamo.testing.CompileCounter()
774*da0073e9SAndroid Build Coastguard Worker
775*da0073e9SAndroid Build Coastguard Worker        def fw_hook(inp):
776*da0073e9SAndroid Build Coastguard Worker            return (inp[0] + 1,)
777*da0073e9SAndroid Build Coastguard Worker
778*da0073e9SAndroid Build Coastguard Worker        class Mod(torch.nn.Module):
779*da0073e9SAndroid Build Coastguard Worker            def __init__(self) -> None:
780*da0073e9SAndroid Build Coastguard Worker                super().__init__()
781*da0073e9SAndroid Build Coastguard Worker                self.layers = torch.nn.ModuleList()
782*da0073e9SAndroid Build Coastguard Worker                for i in range(10):
783*da0073e9SAndroid Build Coastguard Worker                    layer = torch.nn.Linear(16, 16)
784*da0073e9SAndroid Build Coastguard Worker                    layer.register_forward_pre_hook(lambda _, inp: fw_hook(inp))
785*da0073e9SAndroid Build Coastguard Worker                    layer = torch.compile(layer, backend=cnts)
786*da0073e9SAndroid Build Coastguard Worker                    self.layers.append(layer)
787*da0073e9SAndroid Build Coastguard Worker
788*da0073e9SAndroid Build Coastguard Worker            def forward(self, x):
789*da0073e9SAndroid Build Coastguard Worker                for l in self.layers:
790*da0073e9SAndroid Build Coastguard Worker                    x = l(x)
791*da0073e9SAndroid Build Coastguard Worker                return x
792*da0073e9SAndroid Build Coastguard Worker
793*da0073e9SAndroid Build Coastguard Worker        mod = Mod()
794*da0073e9SAndroid Build Coastguard Worker        x = torch.ones(16, 16, requires_grad=True)
795*da0073e9SAndroid Build Coastguard Worker        mod(x)
796*da0073e9SAndroid Build Coastguard Worker
797*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cnts.frame_count, 1)
798*da0073e9SAndroid Build Coastguard Worker
799*da0073e9SAndroid Build Coastguard Worker
800*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__":
801*da0073e9SAndroid Build Coastguard Worker    from torch._dynamo.test_case import run_tests
802*da0073e9SAndroid Build Coastguard Worker
803*da0073e9SAndroid Build Coastguard Worker    run_tests()
804