# Owner(s): ["module: dynamo"]

import contextlib
import functools
import unittest

import torch
import torch._dynamo
import torch._dynamo.test_case
import torch._dynamo.testing
from functorch.compile import nop
from torch._dynamo import compiled_autograd
from torch._functorch.aot_autograd import aot_module_simplified
from torch.utils.hooks import RemovableHandle


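# Used as the compiled_autograd compiler in the tests below: it re-runs the
# generated backward graph through dynamo with the inductor backend.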
def compiler_fn(gm):
    return torch._dynamo.optimize("inductor", nopython=True, dynamic=True)(gm)


def global_hook_0(grad):
    return grad * 4


def global_hook_1(grad):
    return grad / 2


def global_hook_2(grad):
    return grad * 3


h0 = None


class ClassWithVal:
    def __init__(self, val):
        self.val = val


class HooksTests(torch._dynamo.test_case.TestCase):
    def test_tensor_only_register_hook_in_graph_lambda(self):
        def fn(x):
            x.register_hook(lambda grad: grad * 2)
            return x

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v)
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 0)

    def test_tensor_register_hook_in_graph_lambda(self):
        def fn(x, y, z):
            x.register_hook(lambda grad: grad * 2)
            return x, y * y, z * z

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 1)

    def test_tensor_register_hook_in_graph_break_handle_lambda(self):
        def fn(x, y, z):
            handle = x.register_hook(lambda grad: grad * 2)
            z = z * z
            handle.remove()
            x.register_hook(lambda grad: grad * 3)
            return x, y * y, z

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
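        # Only the grad * 3 hook is still registered when backward runs; the
        # grad * 2 hook was removed inside fn.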
        self.assertEqual(v.grad, torch.tensor([3.0, 6.0, 9.0]))
        self.assertEqual(cnts.frame_count, 1)

    def test_tensor_register_hook_multi_handle_return(self):
        def fn(x, y, z):
            handle = x.register_hook(lambda grad: grad * 2)
            h2 = handle
            z = z * z
            return x, y * y, z, handle, h2

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 1)
        self.assertNotEqual(h, None)
        self.assertNotEqual(h2, None)
        self.assertEqual(h2, h)

    def test_tensor_register_hook_repeated_handle_return(self):
        def fn(x, y, z):
            handle = x.register_hook(lambda grad: grad * 2)
            h2 = handle
            z = z * z
            return x, y * y, z, handle, handle

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 1)
        self.assertIsInstance(h, RemovableHandle)
        self.assertIs(h2, h)

    def test_removed_handle_return(self):
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(x, y, z):
            handle = x.register_hook(lambda grad: grad * 2)
            z = z * z
            handle.remove()
            handle.remove()
            return x, y * y, z, handle, handle

        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v, y, z, h, h2 = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
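        # The hook was removed before backward ran (the second remove() is a
        # no-op), so the gradient passes through unchanged.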
        self.assertEqual(v.grad, torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(cnt.frame_count, 1)
        self.assertIsInstance(h, RemovableHandle)
        self.assertIs(h2, h)

    def test_tensor_register_hook_repeated_handle_not_local(self):
        def fn(x, y, z, mod):
            mod.handle = x.register_hook(lambda grad: grad * 2)
            z = z * z
            return x, y * y, z

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)

        mod = torch.nn.Module()
        mod.handle = None

        v, y, z = fn(v, torch.randn([2, 2]), torch.randn([2, 2]), mod)
        v.backward(torch.tensor([1.0, 2.0, 3.0]))

        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 1)

        self.assertNotEqual(mod.handle, None)

    def test_tensor_only_register_hook_in_graph_local(self):
        def local_hook(grad):
            return grad * 2

        def fn(x):
            x.register_hook(local_hook)
            return x

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v)
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 0)

    def test_tensor_only_register_hook_in_graph_local_inner(self):
        def fn(x):
            def local_hook(grad):
                return grad * 2

            z = x * x
            x.register_hook(local_hook)
            z.register_hook(local_hook)
            return x, z

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v)
        v[0].backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v[0].grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 1)

    def test_tensor_register_hook_in_graph_local(self):
        def local_hook(grad):
            return grad * 2

        def fn(x, y, z):
            x.register_hook(local_hook)
            return x, y * y, z * z

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([2.0, 4.0, 6.0]))
        self.assertEqual(cnts.frame_count, 1)

    def test_tensor_register_hook_in_graph_break_handle_local(self):
        def local_hook(grad):
            return grad * 2

        def local_hook2(grad):
            return grad * 3

        def fn(x, y, z):
            handle = x.register_hook(local_hook)
            z = z * z
            handle.remove()
            x.register_hook(local_hook2)
            return x, y * y, z

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v, torch.randn([2, 2]), torch.randn([2, 2]))[0]
        v.backward(torch.tensor([1.0, 2.0, 3.0]))

        self.assertEqual(v.grad, torch.tensor([3.0, 6.0, 9.0]))

    def test_tensor_register_global_hook(self):
        def fn(x):
            x.register_hook(global_hook_0)
            return x, x * x

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v)[0]
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([4.0, 8.0, 12.0]))
        self.assertEqual(cnts.frame_count, 1)

    def test_tensor_register_multiple_hooks(self):
        def fn(x):
            x.register_hook(global_hook_0)  # * 4
            x.register_hook(global_hook_1)  # / 2
            x.register_hook(global_hook_2)  # * 3
            return x, x * x

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v = fn(v)[0]
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
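        # The hooks chain, so the net scaling is 4 * (1 / 2) * 3 = 6.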
        self.assertEqual(v.grad, torch.tensor([6.0, 12.0, 18.0]))
        self.assertEqual(cnts.frame_count, 1)

    def test_tensor_register_multiple_hooks_handles_in_list(self):
        def fn(x):
            h0 = x.register_hook(global_hook_0)  # * 4
            h1 = x.register_hook(global_hook_1)  # / 2
            h2 = x.register_hook(global_hook_2)  # * 3
            return x, x * x, h0, h1, h2

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v, r, handle_0, handle_1, handle_2 = fn(v)
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([6.0, 12.0, 18.0]))
        handle_0.remove()
        handle_1.remove()
        handle_2.remove()

        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        # Hooks removed, so the raw grad [1, 2, 3] accumulates onto the previous grad
        self.assertEqual(v.grad, torch.tensor([7.0, 14.0, 21.0]))

        self.assertEqual(cnts.frame_count, 1)

    def test_tensor_register_global_hooks_handles_in_list(self):
        def fn(x):
            global h0
            h0 = x.register_hook(global_hook_0)  # * 4
            return x, x * x

        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts)(fn)
        v = torch.tensor([0.0, 0.0, 0.0], requires_grad=True)
        v, r = fn(v)

        self.assertIsNotNone(h0)
        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        self.assertEqual(v.grad, torch.tensor([4.0, 8.0, 12.0]))
        h0.remove()

        v.backward(torch.tensor([1.0, 2.0, 3.0]))
        # Hook removed, so the raw grad [1, 2, 3] accumulates onto the previous grad
        self.assertEqual(v.grad, torch.tensor([5.0, 10.0, 15.0]))

        # NYI! Storing the hook handle in a global is not compiled yet, so no frame is recorded
        self.assertEqual(cnts.frame_count, 0)

    def test_intermediary_hooks(self):
        # Graph breaks because compiled_autograd is not set
        def simple_hook(g):
            return g * 2

        def f(x):
            y = x + 1
            y.register_hook(simple_hook)
            z = y + 1
            return z

        out = torch.randn(1, requires_grad=True)
        cnts = torch._dynamo.testing.CompileCounter()
        fn = torch._dynamo.optimize(cnts, nopython=False)(f)
        res = fn(out)
        res.backward()
        self.assertEqual(res, f(out))
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(out.grad, torch.Tensor([2.0]))

    def test_intermediary_hooks_same_on_aot_eager(self):
        def my_hook(grad, *, k=0):
            return grad + k

        class MyMod(torch.nn.Module):
            def forward(self, x):
                y = x.mul(2)
                hook1 = functools.partial(my_hook, k=3)
                hook2 = functools.partial(my_hook, k=4)
                y.register_hook(hook1)
                y.register_hook(hook2)
                z = y.mul(3)
                return (z,)

        mod = MyMod()
        x0 = torch.ones(4, requires_grad=True)
        eager_out = mod(x0)
        eager_out[0].backward(torch.ones(4))

        x1 = torch.ones(4, requires_grad=True)
        mod_compiled = aot_module_simplified(mod, (x1,), nop)
        aot_out = mod_compiled(x1)
        aot_out[0].backward(torch.ones(4))

        x2 = torch.ones(4, requires_grad=True)
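        # With compiled_autograd enabled, the hooks on the intermediate tensor can
        # run under the compiled backward instead of forcing a graph break (see
        # test_intermediary_hooks above).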
        with compiled_autograd.enable(compiler_fn):
            dynamo_out = torch._dynamo.optimize("aot_eager", nopython=True)(mod)(x2)
            dynamo_out[0].backward(torch.ones(4))

        self.assertEqual(dynamo_out, aot_out)
        self.assertEqual(dynamo_out, eager_out)

        self.assertEqual(x0.grad, x1.grad)
        self.assertEqual(x0.grad, x2.grad)

    def test_input_hooks_same(self):
        backends = ["eager", "aot_eager", "inductor"]
        for backend in backends:

            def my_hook(grad, *, k=0):
                return grad + k

            hook = functools.partial(my_hook, k=3)

            class MyMod(torch.nn.Module):
                def forward(self, x):
                    x.register_hook(hook)
                    y = x.mul(2)
                    z = y.mul(3)
                    return (z,)

            mod = MyMod()
            x0 = torch.ones(4, requires_grad=True)
            eager_out = mod(x0)
            eager_out[0].backward(torch.ones(4))

            x1 = torch.ones(4, requires_grad=True)
            mod_compiled = aot_module_simplified(mod, (x1,), nop)
            aot_out = mod_compiled(x1)
            aot_out[0].backward(torch.ones(4))

            x2 = torch.ones(4, requires_grad=True)
            dynamo_out = torch._dynamo.optimize(backend, nopython=True)(mod)(x2)
            with compiled_autograd.enable(compiler_fn):
                dynamo_out[0].backward(torch.ones(4))

            self.assertEqual(dynamo_out, aot_out)
            self.assertEqual(dynamo_out, eager_out)

            self.assertEqual(x0.grad, x1.grad)
            self.assertEqual(x0.grad, x2.grad)

    def test_intermediary_hooks_same_on_inductor(self):
        def my_hook(grad, *, k=0):
            return grad + k

        class MyMod(torch.nn.Module):
            def forward(self, x):
                y = x.mul(2)
                hook1 = functools.partial(my_hook, k=3)
                hook2 = functools.partial(my_hook, k=4)
                y.register_hook(hook1)
                y.register_hook(hook2)
                z = y.mul(3)
                return (z,)

        mod = MyMod()
        x0 = torch.ones(4, requires_grad=True)
        eager_out = mod(x0)
        eager_out[0].backward(torch.ones(4))

        x1 = torch.ones(4, requires_grad=True)
        mod_compiled = aot_module_simplified(mod, (x1,), nop)
        aot_out = mod_compiled(x1)
        aot_out[0].backward(torch.ones(4))

        x2 = torch.ones(4, requires_grad=True)
        with compiled_autograd.enable(compiler_fn):
            dynamo_out = torch._dynamo.optimize("inductor", nopython=True)(mod)(x2)
            dynamo_out[0].backward(torch.ones(4))

        self.assertEqual(dynamo_out, aot_out)
        self.assertEqual(dynamo_out, eager_out)

        self.assertEqual(x0.grad, x1.grad)
        self.assertEqual(x0.grad, x2.grad)

    def test_complex_state_mutation_in_intermediary_hooks_same_on_inductor(self):
        class SomePyClass:
            count = 0

            def do_stuff(self, grad):
                if self.count % 2 == 0:
                    r = grad * grad
                else:
                    r = grad + grad
                self.count += 1
                return r

        def complex_state_touching_hook(grad, *, obj):
            return obj.do_stuff(grad)

        class MyMod(torch.nn.Module):
            def forward(self, x, obj):
                y = x.mul(2)
                hook1 = functools.partial(complex_state_touching_hook, obj=obj)
                hook2 = functools.partial(complex_state_touching_hook, obj=obj)
                y.register_hook(hook1)
                y.register_hook(hook2)
                z = y.mul(3)
                return (z,)

        mod = MyMod()
        obj = SomePyClass()
        x0 = torch.ones(4, requires_grad=True)
        eager_out = mod(x0, obj)
        eager_out[0].backward(torch.ones(4))

        # Two hook invocations from the eager backward
        self.assertEqual(obj.count, 2)
        x2 = torch.ones(4, requires_grad=True)
        with compiled_autograd.enable(compiler_fn):
            dynamo_out = torch._dynamo.optimize("inductor", nopython=True)(mod)(x2, obj)
            dynamo_out[0].backward(torch.ones(4))

        self.assertEqual(dynamo_out, eager_out)

        # Two eager hook invocations plus two more from the compiled backward
        self.assertEqual(obj.count, 4)
        self.assertEqual(x0.grad, x2.grad)

    def test_complex_state_mutation_in_intermediary_hooks_same_on_inductor_with_graph_break(
        self,
    ):
        class SomePyClass:
            grad_as_str = "None"
            count = 0

            def write_grad_as_str_and_do_stuff(self, grad):
                self.grad_as_str = str(grad)
                if self.count % 2 == 0:
                    r = grad * grad
                else:
                    r = grad + grad
                print("Break!")
                self.count += 1
                return r

        def complex_state_touching_hook(grad, *, obj):
            return obj.write_grad_as_str_and_do_stuff(grad)

        class MyMod(torch.nn.Module):
            def forward(self, x, obj):
                y = x.mul(2)
                hook1 = functools.partial(complex_state_touching_hook, obj=obj)
                hook2 = functools.partial(complex_state_touching_hook, obj=obj)
                y.register_hook(hook1)
                y.register_hook(hook2)
                z = y.mul(3)
                return (z,)

        mod = MyMod()
        obj = SomePyClass()
        x0 = torch.ones(4, requires_grad=True)
        eager_out = mod(x0, obj)
        eager_out[0].backward(torch.ones(4))

        x2 = torch.ones(4, requires_grad=True)
        with compiled_autograd.enable(compiler_fn):
            dynamo_out = torch._dynamo.optimize("inductor", nopython=True)(mod)(x2, obj)
            with self.assertRaisesRegex(torch._dynamo.exc.Unsupported, "builtin: str"):
                dynamo_out[0].backward(torch.ones(4))

        self.assertEqual(obj.count, 2)

    def test_register_hook_partial_guarding(
        self,
    ):
        def some_hook(grad, *, obj):
            return grad + obj.val

        class MyMod(torch.nn.Module):
            def forward(self, x, obj):
                y = x.mul(2)
                hook1 = functools.partial(some_hook, obj=obj)
                y.register_hook(hook1)
                z = y.mul(3)
                return (z,)

        mod = MyMod()
        obj1 = ClassWithVal(torch.tensor(88))
        obj2 = ClassWithVal(torch.tensor(99))
        obj3 = ClassWithVal(11)
        cnt = torch._dynamo.testing.CompileCounter()

        x0 = torch.ones(4, requires_grad=True)
        x1 = torch.ones(4, requires_grad=True)

        with compiled_autograd.enable(compiler_fn):
            torch.compile(mod, backend=cnt, fullgraph=True)(x0, obj1)
            torch.compile(mod, backend=cnt, fullgraph=True)(x1, obj1)
            torch.compile(mod, backend=cnt, fullgraph=True)(x0, obj2)
            torch.compile(mod, backend=cnt, fullgraph=True)(x0, obj3)
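            # One frame despite three distinct obj instances: the partial-bound
            # object should not add guards that force recompilation.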
            self.assertEqual(cnt.frame_count, 1)

    def test_hook_with_closure(self):
        def fn(x, obj):
            y = x.sin()
            x.register_hook(lambda grad: grad + obj.val)
            z = y.sin()
            return z

        cnt_fw = torch._dynamo.testing.CompileCounter()
        cnt_bw = torch._dynamo.testing.CompileCounter()
        opt = torch.compile(fn, backend=cnt_fw, fullgraph=True)

        obj1 = ClassWithVal(torch.tensor(88))
        obj2 = ClassWithVal(torch.tensor(99))
        x0 = torch.ones(4, requires_grad=True)
        x1 = torch.ones(4, requires_grad=True)
        x2 = torch.ones(4, requires_grad=True)
        x3 = torch.ones(4, requires_grad=True)
        fn(x0, obj1).sum().backward()
        fn(x1, obj2).sum().backward()

        with compiled_autograd.enable(
            functools.partial(torch.compile, backend=cnt_bw, fullgraph=True)
        ):
            opt(x2, obj1).sum().backward()
            opt(x3, obj2).sum().backward()
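            # A single forward frame and a single backward frame cover both
            # objects; the closed-over obj.val does not trigger a recompile.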
            self.assertEqual(cnt_fw.frame_count, 1)
            self.assertEqual(cnt_bw.frame_count, 1)

        self.assertEqual(x0.grad, x2.grad)
        self.assertEqual(x1.grad, x3.grad)

    def test_intermediate_hook_with_closure_eager(self):
        def fn(x, obj):
            y = x.sin()
            y.register_hook(lambda grad: grad + obj.val)
            z = y.sin()
            return z

        cnt_fw = torch._dynamo.testing.CompileCounter()
        cnt_bw = torch._dynamo.testing.CompileCounter()
        opt = torch.compile(fn, backend=cnt_fw, fullgraph=True)

        obj1 = ClassWithVal(torch.tensor(88))
        obj2 = ClassWithVal(torch.tensor(99))
        x0 = torch.ones(4, requires_grad=True)
        x1 = torch.ones(4, requires_grad=True)
        x2 = torch.ones(4, requires_grad=True)
        x3 = torch.ones(4, requires_grad=True)
        fn(x0, obj1).sum().backward()
        fn(x1, obj2).sum().backward()

        with compiled_autograd.enable(
            functools.partial(torch.compile, backend=cnt_bw, fullgraph=True)
        ):
            opt(x2, obj1).sum().backward()
            opt(x3, obj2).sum().backward()
            self.assertEqual(cnt_fw.frame_count, 1)
            self.assertEqual(cnt_bw.frame_count, 1)

        self.assertEqual(x0.grad, x2.grad)
        self.assertEqual(x1.grad, x3.grad)

    def test_intermediate_hook_with_closure_aot(self):
        def fn(x, obj):
            y = x.sin()
            y.register_hook(lambda grad: grad + obj.val)
            z = y.sin()
            return z

        cnt_bw = torch._dynamo.testing.CompileCounter()
        opt = torch.compile(fn, backend="aot_eager", fullgraph=True)

        obj1 = ClassWithVal(torch.tensor(88))
        obj2 = ClassWithVal(torch.tensor(99))
        x0 = torch.ones(4, requires_grad=True)
        x1 = torch.ones(4, requires_grad=True)
        x2 = torch.ones(4, requires_grad=True)
        x3 = torch.ones(4, requires_grad=True)
        fn(x0, obj1).sum().backward()
        fn(x1, obj2).sum().backward()

        with compiled_autograd.enable(
            functools.partial(torch.compile, backend=cnt_bw, fullgraph=True)
        ):
            opt(x2, obj1).sum().backward()
            opt(x3, obj2).sum().backward()
            self.assertEqual(cnt_bw.frame_count, 1)

        self.assertEqual(x0.grad, x2.grad)
        self.assertEqual(x1.grad, x3.grad)

    def test_no_recompile_on_hook_identity_change(self):
        def my_hook(grad, k=0):
            return grad + k

        def my_hook2(grad):
            return grad * 2

        class MyMod(torch.nn.Module):
            def forward(self, x):
                y = x.mul(2)
                y.register_hook(my_hook)
                y.register_hook(my_hook)
                z = y.mul(3)
                return (z,)

        mod = MyMod()
        x0 = torch.ones(4, requires_grad=True)
        eager_out = mod(x0)
        eager_out[0].backward(torch.ones(4))

        x1 = torch.ones(4, requires_grad=True)
        with compiled_autograd.enable(compiler_fn):
            cnts = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
            comp_mod = torch._dynamo.optimize(cnts, nopython=True)(mod)
            comp_out = comp_mod(x1)
            comp_out[0].backward(torch.ones(4))

            self.assertEqual(cnts.frame_count, 1)
            my_hook = my_hook2  # noqa: F811
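            # Rebinding my_hook must not force a recompile of comp_mod
            # (frame_count stays at 1 below).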
            self.assertEqual(x0.grad, x1.grad)

            eager_out = mod(x0)
            eager_out[0].backward(torch.ones(4))

            comp_out = comp_mod(x1)

            self.assertEqual(cnts.frame_count, 1)
            comp_out[0].backward(torch.ones(4))
            self.assertEqual(x0.grad, x1.grad)

    def test_functools_arg_vary(self):
        def pre_hook(grad, *, k):
            return grad * k

        hook = functools.partial(pre_hook, k=1)

        @torch.compile(backend="eager", fullgraph=True)
        def h(x):
            y = x.mul(2)
            y.register_hook(hook)
            return y.mul(3)

        with compiled_autograd.enable(torch.compile(backend="eager", fullgraph=True)):
            x = torch.randn(2, requires_grad=True)
            h(x).sum().backward()
            orig_grad = x.grad
            x.grad = None

            hook = functools.partial(pre_hook, k=2)
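            # With k=2 the hook scales the gradient by 2 instead of 1, so the
            # new grad is exactly twice the original.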
            h(x).sum().backward()
            self.assertEqual(orig_grad * 2, x.grad)

    def test_post_acc_grad_hook(self):
        def hook(input_t):
            input_t.mul_(input_t.grad)
            input_t.grad.mul_(5)

        def reg_and_mul(x, y):
            x.register_post_accumulate_grad_hook(hook)
            return x * y

        cnts = None

        def test_fn(fn):
            fn(x, y)
            b = torch.tensor([2.0, 2.0, 2.0], requires_grad=True)
            x.backward(b)
            if cnts:
                self.assertEqual(cnts.frame_count, 1)
            # The same assertions run for both the eager and the compiled variants.
            # x becomes x * 2 because the hook multiplies it in place by its grad
            self.assertEqual(x, torch.tensor([0.5, 0.5, 0.5]) * 2)
            # Grad aliasing works: the hook's in-place mul_(5) on x.grad is visible here
            self.assertEqual(x.grad, b * 5)

        # Eager values
        x = torch.tensor([0.5, 0.5, 0.5], requires_grad=True)
        y = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
        test_fn(reg_and_mul)

        # Compiled
        for backend in ["eager", "aot_eager", "inductor"]:
            for compiled_bwd in [False, True]:
                torch._dynamo.reset()
                x = torch.tensor([0.5, 0.5, 0.5], requires_grad=True)
                y = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

                cnts = torch._dynamo.testing.CompileCounterWithBackend(backend)
                compiled_fn = torch._dynamo.optimize(cnts, nopython=True)(reg_and_mul)

                compiled_bwd_ctx = (
                    compiled_autograd.enable(
                        torch.compile(backend=backend, fullgraph=True)
                    )
                    if compiled_bwd
                    else contextlib.nullcontext()
                )
                with compiled_bwd_ctx:
                    test_fn(compiled_fn)

    def test_recompile(self):
        def hook(param):
            param.grad *= 2

        x = torch.ones(10)
        x.requires_grad = True

        def run(input):
            return x * input

        x.register_post_accumulate_grad_hook(hook)
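        # error_on_recompile is patched to True inside the loop, so any
        # recompilation across the five iterations would raise.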
        with compiled_autograd.enable(compiler_fn):
            for i in range(5):
                with unittest.mock.patch(
                    "torch._dynamo.config.error_on_recompile", True
                ):
                    # Mimic optimizer.zero_grad() to clear the gradient
                    x.grad = None
                    run(i).sum().backward()

    @torch._dynamo.config.patch(inline_inbuilt_nn_modules=True)
    def test_no_recompile_on_same_hook(self):
        cnts = torch._dynamo.testing.CompileCounter()

        def fw_hook(inp):
            return (inp[0] + 1,)

        class Mod(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.layers = torch.nn.ModuleList()
                for i in range(10):
                    layer = torch.nn.Linear(16, 16)
                    layer.register_forward_pre_hook(lambda _, inp: fw_hook(inp))
                    layer = torch.compile(layer, backend=cnts)
                    self.layers.append(layer)

            def forward(self, x):
                for l in self.layers:
                    x = l(x)
                return x

        mod = Mod()
        x = torch.ones(16, 16, requires_grad=True)
        mod(x)

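        # A single frame serves all ten compiled layers: registering the same
        # pre-hook on each layer does not trigger per-layer recompiles.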
        self.assertEqual(cnts.frame_count, 1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()