1# Owner(s): ["module: dynamo"]
2import unittest
3
4import torch
5import torch._dynamo.test_case
6import torch._dynamo.testing
7import torch.onnx.operators
8from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm, same
9from torch.nn import functional as F
10from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
11from torch.testing._internal.common_utils import TEST_WITH_ROCM
12
13
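# A user-defined context manager that toggles grad mode through torch._C._set_grad_enabled.
# The tests below use it to exercise Dynamo's handling of generic (non-torch) context managers.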
14class CustomizedCtxManager:
15    def __init__(self, mode):
16        self.prev = torch.is_grad_enabled()
17        self.mode = mode
18
19    def __enter__(self):
20        torch._C._set_grad_enabled(self.mode)
21
22    def __exit__(self, exc_type, exc_value, traceback):
23        torch._C._set_grad_enabled(self.prev)
24
25
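# Same as above, but __enter__ triggers a graph break before delegating to the parent.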
26class CustomizedCtxManagerWithGraphBreak(CustomizedCtxManager):
27    def __enter__(self):
28        torch._dynamo.graph_break()
29        super().__enter__()
30
31
32class CtxManagerTests(torch._dynamo.test_case.TestCase):
33    def test_no_grad(self):
34        def fn1(a, b):
35            x = a + 1
36            # redundant no_grad should get ignored
37            with torch.no_grad():
38                x = x + b
39            x = x + 2
40            return x
41
42        def fn2(a, b):
43            x = a + 1
44            with torch.set_grad_enabled(False):
45                x = x + b
46            x = x + 2
47            return x
48
49        def fn3(a, b):
50            x = a + 1
51            with torch.enable_grad():
52                x = x + b
53            x = x + 2
54            return x
55
56        def fn4(a, b):
57            x = a + 1
58            with torch.set_grad_enabled(True):
59                if torch.is_grad_enabled():
60                    x = x + b
61            x = x + 2
62            return x
63
64        with torch.no_grad():
65            torch._dynamo.testing.standard_test(
66                self, fn=fn1, nargs=2, expected_ops=3
67            )  # coalesced noop
68            torch._dynamo.testing.standard_test(
69                self, fn=fn2, nargs=2, expected_ops=3
70            )  # coalesced noop
71            torch._dynamo.testing.standard_test(self, fn=fn3, nargs=2, expected_ops=5)
72            torch._dynamo.testing.standard_test(self, fn=fn4, nargs=2, expected_ops=5)
73        with torch.enable_grad():
74            torch._dynamo.testing.standard_test(self, fn=fn1, nargs=2, expected_ops=5)
75            torch._dynamo.testing.standard_test(self, fn=fn2, nargs=2, expected_ops=5)
76            torch._dynamo.testing.standard_test(
77                self, fn=fn3, nargs=2, expected_ops=3
78            )  # coalesced noop
79            torch._dynamo.testing.standard_test(
80                self, fn=fn4, nargs=2, expected_ops=3
81            )  # coalesced noop
82
83    def test_grad_mode_guard(self):
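        # Toggling grad mode around a graph break (a.tolist()) should yield exactly two
        # compiled frames across repeated calls.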
84        def fn(a, b):
85            prev_grad = torch.is_grad_enabled()
86            torch.set_grad_enabled(False)
87            a = a + 1
88            a.tolist()  # graph break
89            ret = a + b
90            torch.set_grad_enabled(prev_grad)
91            return ret
92
93        a = torch.randn([3, 4])
94        b = torch.randn([3, 4])
95        cnts = torch._dynamo.testing.CompileCounter()
96        opt_fn = torch._dynamo.optimize(cnts)(fn)
97        for _ in range(10):
98            opt_fn(a, b)
99        self.assertEqual(cnts.frame_count, 2)
100
101    def test_nested_grad_mode_graph_break(self):
102        def fn(x):
103            before = torch.is_grad_enabled()
104            with torch.set_grad_enabled(False):
105                torch._dynamo.graph_break()
106                with torch.set_grad_enabled(True):
107                    x = torch.mul(x, 5)
108                    torch._dynamo.graph_break()
109                    x = torch.sqrt(x)
110                    assert torch.is_grad_enabled()
111                assert not torch.is_grad_enabled()
112            assert torch.is_grad_enabled() == before
113            return x
114
115        a = torch.randn([3, 4])
116        cnts = torch._dynamo.testing.CompileCounter()
117        opt_fn = torch._dynamo.optimize(cnts)(fn)
118
119        for _ in range(10):
120            opt_fn(a)
121        self.assertEqual(cnts.frame_count, 2)
122
123    def test_torch_profiler(self):
124        # torch.profiler.* is wrapped as NullContextVariable and does nothing
125        def fn(x):
126            y = x**2
127            with torch.profiler.profile():
128                y = y + 2
129                with torch.profiler.record_function("my_function"):
130                    z = y**3
131                    z.tolist()  # graph break
132                    z = z + 1
133            return z
134
135        x = torch.randn((2, 2), requires_grad=True)
136        ref = fn(x)
137        cnts = torch._dynamo.testing.CompileCounter()
138        opt_fn = torch._dynamo.optimize(cnts)(fn)
139        res = opt_fn(x)
140        self.assertTrue(same(ref, res))
141        self.assertEqual(cnts.frame_count, 2)
142
143    def test_autograd_profiler(self):
144        # torch.autograd.profiler.* is wrapped as NullContextVariable and does nothing
145        def fn(x):
146            y = x**2
147            with torch.autograd.profiler.profile():
148                y = y + 2
149                with torch.autograd.profiler.record_function("my_function"):
150                    z = y**3
151                    z.tolist()  # graph break
152                    z = z + 1
153            return z
154
155        x = torch.randn((2, 2), requires_grad=True)
156        ref = fn(x)
157        cnts = torch._dynamo.testing.CompileCounter()
158        opt_fn = torch._dynamo.optimize(cnts)(fn)
159        res = opt_fn(x)
160        self.assertTrue(same(ref, res))
161        self.assertEqual(cnts.frame_count, 2)
162
163    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
164    def test_cuda_stream_context_manager1(self):
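        # A CUDA stream created and used entirely inside the compiled region should be
        # captured without graph breaks (nopython=True, a single frame).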
165        def fn(x):
166            s = torch.cuda.Stream()
167            x = torch.mul(x, 5)
168            x = torch.add(x, 2)
169            current_stream = torch.cuda.current_stream()
170            s.wait_stream(current_stream)
171            with torch.cuda.stream(s):
172                x = torch.relu(x)
173            current_stream.wait_stream(s)
174            x = torch.add(x, 1)
175            x = torch.cos(x)
176            return x
177
178        x = torch.randn((2, 2), device="cuda")
179        ref = fn(x)
180        cnts = torch._dynamo.testing.CompileCounter()
181        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
182        res = opt_fn(x)
183        self.assertEqual(ref, res)
184        self.assertEqual(cnts.frame_count, 1)
185        self.assertEqual(cnts.op_count, 12)
186
187    @unittest.expectedFailure  # https://github.com/pytorch/pytorch/issues/118204
188    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
189    def test_cuda_stream_across_graph_break(self):
190        def fn(x):
191            s = torch.cuda.Stream()
192            x = torch.mul(x, 5)
193            x = torch.add(x, 2)
194
195            print("foo")
196
197            tcs = torch.cuda.stream(s)
198            current_stream = torch.cuda.current_stream()
199            s.wait_stream(current_stream)
200
201            with tcs:
202                x = torch.relu(x)
203
204            current_stream.wait_stream(s)
205            x = torch.add(x, 1)
206            x = torch.cos(x)
207            return x
208
209        x = torch.randn((2, 2), device="cuda")
210        ref = fn(x)
211        cnts = torch._dynamo.testing.CompileCounter()
212        opt_fn = torch._dynamo.optimize(cnts)(fn)
213        res = opt_fn(x)
214        self.assertEqual(ref, res)
215        self.assertEqual(cnts.frame_count, 2)
216        self.assertEqual(cnts.op_count, 9)
217
218    @unittest.expectedFailure  # https://github.com/pytorch/pytorch/issues/118204
219    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
220    def test_cuda_stream_context_manager2(self):
221        def fn(x, s):
222            x = torch.mul(x, 5)
223            x = torch.add(x, 2)
224
225            current_stream = torch.cuda.current_stream()
226            s.wait_stream(current_stream)
227
228            with torch.cuda.stream(s):
229                x = torch.relu(x)
230
231            current_stream.wait_stream(s)
232            with torch.cuda.stream(current_stream):
233                x = torch.relu(x)
234
235            s2 = torch.cuda.Stream()
236            s2.wait_stream(current_stream)
237            with torch.cuda.stream(s2):
238                x = torch.relu(x)
239
240            current_stream.wait_stream(s2)
241            x = torch.add(x, 1)
242            x = torch.cos(x)
243            return x
244
245        x = torch.randn((2, 2), device="cuda")
246        s = torch.cuda.Stream()
247        ref = fn(x, s)
248        cnts = torch._dynamo.testing.CompileCounter()
249        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
250        res = opt_fn(x, s)
251        self.assertEqual(ref, res)
252        self.assertEqual(cnts.frame_count, 1)
253        self.assertEqual(cnts.op_count, 18)
254
255    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
256    def test_cuda_stream_method(self):
257        def fn(x):
258            x = torch.mul(x, 1)
259            x = torch.add(x, 2)
260
261            new_stream = torch.cuda.Stream()
262            cur_stream = torch.cuda.current_stream()
263            new_stream.wait_stream(cur_stream)
264
265            with torch.cuda.stream(new_stream):
266                x = torch.sin(x)
267                x = torch.add(x, 3)
268
269            cur_stream.wait_stream(new_stream)
270
271            x = torch.add(x, 4)
272            is_idle = cur_stream.query()
273            cur_stream.synchronize()
274
275            with torch.cuda.stream(new_stream):
276                x = torch.add(x, 5)
277            new_stream.synchronize()
278
279            is_equal = cur_stream == new_stream
280
281            x = torch.relu(x)
282            x = torch.cos(x)
283            return x
284
285        x = torch.randn((2, 2), device="cuda")
286        ref = fn(x)
287        cnts = torch._dynamo.testing.CompileCounter()
288        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
289        res = opt_fn(x)
290        self.assertEqual(ref, res)
291        self.assertEqual(cnts.frame_count, 1)
292        self.assertEqual(cnts.op_count, 21)
293
294    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
295    def test_cuda_stream_compared_with_constant(self):
296        def fn(x):
297            x = torch.mul(x, 1)
298            x = torch.add(x, 2)
299
300            cur_stream = torch.cuda.current_stream()
301            if cur_stream is not None:
302                return x + 1
303            return x - 1
304
305        def fn2(x):
306            x = torch.mul(x, 1)
307            x = torch.add(x, 2)
308
309            cur_stream = torch.cuda.current_stream()
310            if cur_stream != "const_str":
311                return x + 1
312            return x - 1
313
314        x = torch.randn((2, 2), device="cuda")
315        ref = fn(x)
316        cnts = torch._dynamo.testing.CompileCounter()
317        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
318        opt_fn2 = torch._dynamo.optimize(cnts, nopython=True)(fn2)
319        res = opt_fn(x)
320        res2 = opt_fn2(x)
321        self.assertEqual(ref, res)
322        self.assertEqual(ref, res2)
323
324    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
325    def test_cuda_stream_compared_with_stream(self):
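        # Stream identity is guarded on, so calling with a different pair of streams
        # triggers a recompilation.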
326        def fn(x, s0, s1):
327            if s0 == s1:
328                return x + 1
329            else:
330                return x - 1
331
332        s0 = torch.cuda.Stream()
333        s1 = torch.cuda.Stream()
334        x = torch.randn(2, 2)
335        cnts = torch._dynamo.testing.CompileCounter()
336        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
337
338        ref0 = fn(x, s0, s1)
339        res0 = opt_fn(x, s0, s1)
340        self.assertEqual(cnts.frame_count, 1)
341        self.assertEqual(ref0, res0)
342
343        ref1 = fn(x, s1, s1)
344        res1 = opt_fn(x, s1, s1)
345        # We get a recompilation because the inputs changed
346        self.assertEqual(cnts.frame_count, 2)
347        self.assertEqual(ref1, res1)
348
349        torch._dynamo.reset()
350        cnts = torch._dynamo.testing.CompileCounter()
351        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
352
353        ref1 = fn(x, s1, s1)
354        res1 = opt_fn(x, s1, s1)
355        self.assertEqual(cnts.frame_count, 1)
356        self.assertEqual(ref1, res1)
357
358        ref0 = fn(x, s0, s1)
359        res0 = opt_fn(x, s0, s1)
360        # We get a recompilation because the inputs changed
361        self.assertEqual(cnts.frame_count, 2)
362        self.assertEqual(ref0, res0)
363
364    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
365    def test_cuda_event_reconstruct(self):
366        def fn(x):
367            e = torch.cuda.Event()
368            x = torch.mul(x, 5)
369            x = torch.add(x, 2)
370            return x, e
371
372        x = torch.randn((2, 2), device="cuda")
373        ref = fn(x)
374        cnts = torch._dynamo.testing.CompileCounter()
375        opt_fn = torch._dynamo.optimize(cnts)(fn)
376        res = opt_fn(x)
377        self.assertEqual(ref[0], res[0])
378        self.assertEqual(cnts.frame_count, 1)
379        self.assertEqual(cnts.op_count, 3)
380
381    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
382    def test_cuda_event_across_graph_break(self):
383        def fn(x):
384            e = torch.cuda.Event()
385            e.record()
386            x = torch.mul(x, 5)
387            x = torch.add(x, 2)
388
389            print("foo")
390
391            torch.cuda.current_stream().wait_event(e)
392            x = torch.add(x, 1)
393            x = torch.cos(x)
394            return x, e
395
396        x = torch.randn((2, 2), device="cuda")
397        ref = fn(x)
398        cnts = torch._dynamo.testing.CompileCounter()
399        opt_fn = torch._dynamo.optimize(cnts)(fn)
400        res = opt_fn(x)
401        self.assertEqual(ref[0], res[0])
402        self.assertEqual(cnts.frame_count, 2)
403        self.assertEqual(cnts.op_count, 9)
404
405    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
406    def test_cuda_event_created_outside_of_graph(self):
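        # An event created and recorded outside the compiled region can still be waited on
        # (and returned) inside it.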
407        user_stream = torch.cuda.Stream()
408        event = torch.cuda.Event()
409        foo = torch.empty((2, 2), device="cuda")
410
411        def func(foo):
412            event.wait()
413            return foo + 1, event
414
415        x = torch.randn((1024, 1024), device="cuda")
416        cnts = torch._dynamo.testing.CompileCounter()
417
418        def run_iters(fn, compile=False):
419            if compile:
420                fn = torch._dynamo.optimize(cnts)(fn)
421            for _ in range(10):
422                with torch.cuda.stream(user_stream):
423                    torch.mm(x, x, out=foo)
424                    event.record()
425                out = fn(foo)
426            return out
427
428        ref = run_iters(func, compile=False)
429        res = run_iters(func, compile=True)
430        self.assertEqual(ref, res)
431        self.assertEqual(cnts.frame_count, 1)
432        self.assertEqual(cnts.op_count, 3)
433
434    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
435    def test_cuda_event_method_create_stream_outside_of_compile(self):
436        def fn(x, cur_stream, new_stream):
437            x = torch.mul(x, 1)
438            x = torch.add(x, 2)
439
440            x = torch.add(x, 3)
441
442            event = cur_stream.record_event()
443            is_idle = event.query()
444
445            new_stream.wait_event(event)
446            with torch.cuda.stream(new_stream):
447                x = torch.add(x, 4)
448
449            new_event = torch.cuda.Event()
450            new_event.record(new_stream)
451
452            new_event.wait(cur_stream)
453            x = torch.add(x, 5)
454
455            # use new event to sync
456            new_event.synchronize()
457
458            x = torch.relu(x)
459            x = torch.cos(x)
460            return x
461
462        x = torch.randn((2, 2), device="cuda")
463        cur_stream = torch.cuda.current_stream()
464        new_stream = torch.cuda.Stream()
465        ref = fn(x, cur_stream, new_stream)
466        cnts = torch._dynamo.testing.CompileCounter()
467        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
468        res = opt_fn(x, cur_stream, new_stream)
469        self.assertEqual(ref, res)
470        self.assertEqual(cnts.frame_count, 1)
471        self.assertEqual(cnts.op_count, 19)
472
473    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
474    def test_cuda_event_method(self):
475        def fn(x):
476            x = torch.mul(x, 1)
477            x = torch.add(x, 2)
478
479            cur_stream = torch.cuda.current_stream()
480            new_stream = torch.cuda.Stream()
481
482            x = torch.add(x, 3)
483
484            event = cur_stream.record_event()
485            is_idle = event.query()
486
487            new_stream.wait_event(event)
488            with torch.cuda.stream(new_stream):
489                x = torch.add(x, 4)
490
491            new_event = torch.cuda.Event()
492            new_event.record(new_stream)
493
494            new_event.wait(cur_stream)
495            x = torch.add(x, 5)
496
497            # use new event to sync
498            new_event.synchronize()
499
500            x = torch.relu(x)
501            x = torch.cos(x)
502            return x
503
504        x = torch.randn((2, 2), device="cuda")
505        ref = fn(x)
506        cnts = torch._dynamo.testing.CompileCounter()
507        opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
508        res = opt_fn(x)
509        self.assertEqual(ref, res)
510        self.assertEqual(cnts.frame_count, 1)
511        self.assertEqual(cnts.op_count, 19)
512
513    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
514    def test_cuda_device(self):
515        def fn(x):
516            with torch.cuda.device(x.device.index - 1):
517                x = torch.sin(x + 1)
518            return x
519
520        x = torch.randn((2, 2), device="cuda")
521        ref = fn(x)
522        opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
523        res = opt_fn(x)
524        self.assertEqual(ref, res)
525
526    def test_autograd_profiler_enabled(self):
527        def fn(x):
528            if torch.autograd._profiler_enabled():
529                return x + 1
530            else:
531                return x - 1
532
533        x = torch.randn((2, 2), requires_grad=True)
534        cnts = torch._dynamo.testing.CompileCounter()
535        opt_fn = torch._dynamo.optimize(cnts)(fn)
536
537        if torch.autograd._profiler_enabled():
538            torch.autograd._disable_profiler()
539        assert not torch.autograd._profiler_enabled()
540        ref = fn(x)
541        res = opt_fn(x)
542        self.assertTrue(same(ref, res))
543
544        with torch.autograd.profiler.profile():
545            assert torch.autograd._profiler_enabled()
546            ref = fn(x)
547            res = opt_fn(x)
548            self.assertTrue(same(ref, res))
549
550    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
551    def test_autocast(self):
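        # Exporting a module that runs matmuls under cuda autocast should preserve the
        # autocast dtype (bfloat16) in the exported graph's output.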
552        if not torch.cuda.is_bf16_supported():
553            raise unittest.SkipTest("requires bf16")
554
555        class MyModule(torch.nn.Module):
556            def forward(self, x):
557                a_float32 = torch.rand((8, 8), device="cuda")
558                b_float32 = torch.rand((8, 8), device="cuda")
559                d_float32 = torch.rand((8, 8), device="cuda")
560
561                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
562                    e_float16 = torch.mm(a_float32, b_float32)
563                    f_float16 = torch.mm(d_float32, e_float16)
564                return f_float16
565
566        module = MyModule()
567        real = module(torch.tensor([0.5]))
568        real_device = real.device
569        real_dtype = real.dtype
570
571        graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
572        exported = graph(torch.tensor([0.5]))
573        self.assertEqual(exported.device, real_device)
574        self.assertEqual(exported.dtype, real_dtype)
575
576        self.assertEqual(exported.device.type, "cuda")
577        self.assertEqual(exported.device.index, 0)
578        self.assertEqual(exported.dtype, torch.bfloat16)
579
580    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
581    def test_cuda_amp_autocast(self):
582        class MyModule(torch.nn.Module):
583            def forward(self, x):
584                a_float32 = torch.rand((8, 8), device="cuda")
585                b_float32 = torch.rand((8, 8), device="cuda")
586
587                with torch.cuda.amp.autocast(dtype=torch.float64):
588                    c_float64 = torch.mm(a_float32, b_float32)
589                return c_float64
590
591        module = MyModule()
592        real = module(torch.tensor([0.5]))
593        real_device = real.device
594        real_dtype = real.dtype
595
596        graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
597        exported = graph(torch.tensor([0.5]))
598        self.assertEqual(exported.device, real_device)
599        self.assertEqual(exported.dtype, real_dtype)
600
601        self.assertEqual(exported.device.type, "cuda")
602        self.assertEqual(exported.device.index, 0)
603        self.assertEqual(exported.dtype, torch.float64)
604
605    def test_is_autocast_cpu_enabled(self):
606        def fn(a_float32, b_float32):
607            with torch.cpu.amp.autocast(dtype=torch.bfloat16):
608                c_float16 = torch.mm(a_float32, b_float32)
609                if torch.is_autocast_cpu_enabled():
610                    c_float16 = c_float16 + 1
611            return c_float16
612
613        a = torch.rand((8, 8))
614        b = torch.rand((8, 8))
615        ref = fn(a, b)
616        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
617        res = opt_fn(a, b)
618        self.assertTrue(same(ref, res))
619
620    @unittest.skipIf(
621        not PLATFORM_SUPPORTS_FLASH_ATTENTION or TEST_WITH_ROCM,
622        "Can't run fused SDPA on this platform",
623    )
624    def test_autocast_sdpa(self):
625        class MyModule(torch.nn.Module):
626            def forward(self, query, key, value):
627                with torch.autocast("cpu"):
628                    with torch.autocast("cuda", dtype=torch.float32):
629                        out = F.scaled_dot_product_attention(
630                            query, key, value, None, 0.0, True
631                        )
632                return out
633
634        dtype = torch.float32
635        seq_len_q = 1
636        seq_len_k = 1
637        head_dim = 8
638        query = torch.ones(
639            1, 8, seq_len_q, head_dim, device="cuda", dtype=dtype, requires_grad=True
640        )
641        key = torch.ones(
642            1, 8, seq_len_k, head_dim, device="cuda", dtype=dtype, requires_grad=True
643        )
644        value = torch.ones(
645            1, 8, seq_len_k, head_dim, device="cuda", dtype=dtype, requires_grad=True
646        )
647
648        module = MyModule()
649        real = module(query, key, value)
650        real_device = real.device
651        real_dtype = real.dtype
652
653        opt_mod = torch._dynamo.optimize("inductor")(module)
654        compiled = opt_mod(query, key, value)
655
656        self.assertEqual(compiled.device, real_device)
657        self.assertEqual(compiled.dtype, real_dtype)
658
659        self.assertEqual(compiled.device.type, "cuda")
660        self.assertEqual(compiled.device.index, 0)
661        self.assertEqual(compiled.dtype, torch.float32)
662
663    def test_autocast_cpu(self):
664        class MyModule(torch.nn.Module):
665            def forward(self, x):
666                a_float32 = torch.rand((8, 8), device="cpu")
667                b_float32 = torch.rand((8, 8), device="cpu")
668                d_float32 = torch.rand((8, 8), device="cpu")
669
670                with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
671                    e_float16 = torch.mm(a_float32, b_float32)
672                    f_float16 = torch.mm(d_float32, e_float16)
673                return f_float16
674
675        module = MyModule()
676        real = module(torch.tensor([0.5]))
677        real_device = real.device
678        real_dtype = real.dtype
679
680        graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
681        exported = graph(torch.tensor([0.5]))
682        self.assertEqual(exported.device, real_device)
683        self.assertEqual(exported.dtype, real_dtype)
684
685        self.assertEqual(exported.device.type, "cpu")
686        self.assertEqual(exported.dtype, torch.bfloat16)
687
688    def test_autocast_cpu_graph_break(self):
689        class MyModule(torch.nn.Module):
690            def forward(self, x):
691                a_float32 = torch.rand((8, 8), device="cpu")
692                b_float32 = torch.rand((8, 8), device="cpu")
693                torch._dynamo.graph_break()
694                d_float32 = torch.rand((8, 8), device="cpu")
695
696                with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
697                    e_float16 = torch.mm(a_float32, b_float32)
698                    torch._dynamo.graph_break()
699                    f_float16 = torch.mm(d_float32, e_float16)
700                return f_float16
701
702        module = MyModule()
703        real = module(torch.tensor([0.5]))
704        real_device = real.device
705        real_dtype = real.dtype
706
707        opt = torch._dynamo.optimize("eager")(module)
708        res = opt(torch.tensor([0.5]))
709        self.assertEqual(res.device, real_device)
710        self.assertEqual(res.dtype, real_dtype)
711
712        self.assertEqual(res.device.type, "cpu")
713        self.assertEqual(res.dtype, torch.bfloat16)
714
715    def test_autocast_cpu_graph_break_2(self):
716        # Regression for: https://github.com/pytorch/pytorch/issues/93890
717        def fn(x):
718            with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
719                x = torch.mm(x, x)
720                torch._dynamo.graph_break()
721                x = torch.relu(x)
722            return x
723
724        x = torch.rand([4, 4])
725        self.assertEqual(x.dtype, torch.float32)
726        res = fn(x)
727        opt_fn = torch._dynamo.optimize("eager")(fn)
728        opt_res = opt_fn(x)
729        self.assertTrue(torch.allclose(res, opt_res))
730        self.assertEqual(res.dtype, torch.bfloat16)
731        self.assertEqual(opt_res.dtype, torch.bfloat16)
732
733    def test_autocast_cpu_graph_break_inner_fn(self):
734        class MyModule(torch.nn.Module):
735            @staticmethod
736            def mm_breaks(x, y):
737                torch._dynamo.graph_break()
738                return torch.mm(x, y)
739
740            def forward(self, x):
741                a_float32 = torch.rand((8, 8), device="cpu")
742                b_float32 = torch.rand((8, 8), device="cpu")
743
744                with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
745                    torch._dynamo.graph_break()
746                    with torch.autocast(
747                        device_type="cpu", dtype=torch.bfloat16, enabled=False
748                    ):
749                        torch._dynamo.graph_break()
750                        g_float32 = torch.mm(a_float32, b_float32)
751                        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
752                            # Check nested autocast with a non-inlineable function that graph breaks
753                            torch._dynamo.graph_break()
754                            f_float16_1 = self.mm_breaks(a_float32, b_float32)
755                    # We should correctly restore the outer autocast when exiting the inner one,
756                    # even after graph breaks
757                    f_float16 = self.mm_breaks(a_float32, b_float32)
758                    assert f_float16.dtype == f_float16_1.dtype
759                return f_float16, g_float32
760
761        module = MyModule()
762        real_16, real_32 = module(torch.tensor([0.5]))
763        real_device_16 = real_16.device
764        real_dtype_16 = real_16.dtype
765        real_device_32 = real_32.device
766        real_dtype_32 = real_32.dtype
767
768        graph = torch._dynamo.optimize("eager")(module)
769        out_16, out_32 = graph(torch.tensor([0.5]))
770        self.assertEqual(out_16.device, real_device_16)
771        self.assertEqual(out_16.dtype, real_dtype_16)
772        self.assertEqual(out_32.device, real_device_32)
773        self.assertEqual(out_32.dtype, real_dtype_32)
774
775        self.assertEqual(out_16.device.type, "cpu")
776        self.assertEqual(out_16.dtype, torch.bfloat16)
777        self.assertEqual(out_32.device.type, "cpu")
778        self.assertEqual(out_32.dtype, torch.float32)
779
780    def test_autocast_graph_break_method(self):
781        class MyModule(torch.nn.Module):
782            def __init__(self, bias):
783                super().__init__()
784                self.bias = bias
785
786            def mm_not_break(self, x, y):
787                return torch.mm(x, y) + self.bias
788
789            def mm_breaks(self, x, y):
790                torch._dynamo.graph_break()
791                return torch.mm(x, y) + self.bias
792
793            def forward(self, x):
794                a_float32 = torch.rand((8, 8), device="cpu")
795                b_float32 = torch.rand((8, 8), device="cpu")
796
797                with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
798                    with torch.autocast(
799                        device_type="cpu", dtype=torch.bfloat16, enabled=False
800                    ):
801                        g_float32 = torch.mm(a_float32, b_float32)
802                    f_float16 = self.mm_breaks(a_float32, b_float32)
803
804                    assert (
805                        f_float16[0][0] == self.mm_not_break(a_float32, b_float32)[0][0]
806                    )
807                return f_float16, g_float32
808
809        module = MyModule(bias=torch.rand((8, 8), device="cpu", dtype=torch.bfloat16))
810
811        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
812            # Autocast doesn't cast operands of addition, so the bias must already be `bfloat16`
813            res = torch.rand((8, 8), device="cpu", dtype=torch.float32) + torch.rand(
814                (8, 8), device="cpu", dtype=torch.bfloat16
815            )
816            self.assertEqual(res.dtype, torch.float32)
817
818        real_16, real_32 = module(torch.tensor([0.5]))
819        real_device_16 = real_16.device
820        real_dtype_16 = real_16.dtype
821        real_device_32 = real_32.device
822        real_dtype_32 = real_32.dtype
823
824        graph = torch._dynamo.optimize("eager")(module)
825        out_16, out_32 = graph(torch.tensor([0.5]))
826        self.assertEqual(out_16.device, real_device_16)
827        self.assertEqual(out_16.dtype, real_dtype_16)
828        self.assertEqual(out_32.device, real_device_32)
829        self.assertEqual(out_32.dtype, real_dtype_32)
830
831        self.assertEqual(out_16.device.type, "cpu")
832        self.assertEqual(out_16.dtype, torch.bfloat16)
833        self.assertEqual(out_32.device.type, "cpu")
834        self.assertEqual(out_32.dtype, torch.float32)
835
836    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
837    def test_autocast_float64(self):
838        class MyModule(torch.nn.Module):
839            def forward(self, x):
840                a_float32 = torch.rand((8, 8), device="cuda")
841                b_float32 = torch.rand((8, 8), device="cuda")
842                d_float32 = torch.rand((8, 8), device="cuda")
843
844                with torch.autocast(device_type="cuda", dtype=torch.float64):
845                    e_float64 = torch.mm(a_float32, b_float32)
846                    f_float64 = torch.mm(d_float32, e_float64)
847                return f_float64
848
849        module = MyModule()
850        real = module(torch.tensor([0.5]))
851        real_device = real.device
852        real_dtype = real.dtype
853
854        graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
855        exported = graph(torch.tensor([0.5]))
856        self.assertEqual(exported.device, real_device)
857        self.assertEqual(exported.dtype, real_dtype)
858
859        self.assertEqual(exported.device.index, 0)
860        self.assertEqual(exported.dtype, torch.float64)
861
862    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
863    def test_autocast_device(self):
864        class MyModule(torch.nn.Module):
865            def forward(self, x):
866                a_float32 = torch.rand((8, 8), device="cuda")
867                b_float32 = torch.rand((8, 8), device="cuda")
868                d_float32 = torch.rand((8, 8), device="cuda")
869
870                with torch.autocast("cuda"):
871                    e_float64 = torch.mm(a_float32, b_float32)
872                    f_float64 = torch.mm(d_float32, e_float64)
873                return f_float64
874
875        module = MyModule()
876        real = module(torch.tensor([0.5]))
877        real_device = real.device
878        real_dtype = real.dtype
879
880        graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
881        exported = graph(torch.tensor([0.5]))
882        self.assertEqual(exported.device, real_device)
883        self.assertEqual(exported.dtype, real_dtype)
884
885        self.assertEqual(exported.device.index, 0)
886        self.assertEqual(exported.dtype, torch.float16)
887
888    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
889    def test_autocast_arguments_binding(self):
890        def f1(x):
891            with torch.cuda.amp.autocast(False):
892                x = torch.sin(x + 1)
893            return x
894
895        def f2(x):
896            with torch.cpu.amp.autocast(False):
897                x = torch.cos(x + 1)
898            return x
899
900        x = torch.rand([2, 3])
901        ref1 = f1(x)
902        ref2 = f2(x)
903        opt_f1 = torch.compile(backend="eager")(f1)
904        opt_f2 = torch.compile(backend="eager")(f2)
905        res1 = opt_f1(x)
906        res2 = opt_f2(x)
907        self.assertTrue(same(ref1, res1))
908        self.assertTrue(same(ref2, res2))
909
910    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
911    def test_autocast_decorator(self):
912        def autocast_func(orig_func):
913            @torch.amp.autocast(device_type="cuda", dtype=torch.float16)
914            def new_fwd(*args, **kwargs):
915                return orig_func(*args, **kwargs)
916
917            return new_fwd
918
919        def autocast_func_cuda(orig_func):
920            @torch.cuda.amp.autocast(dtype=torch.float16)
921            def new_fwd(*args, **kwargs):
922                return orig_func(*args, **kwargs)
923
924            return new_fwd
925
926        def autocast_func_cpu(orig_func):
927            @torch.cpu.amp.autocast(dtype=torch.float16)
928            def new_fwd(*args, **kwargs):
929                return orig_func(*args, **kwargs)
930
931            return new_fwd
932
933        def mm(a, b):
934            return torch.mm(a, b)
935
936        mm_float16 = autocast_func(mm)
937        mm_float16_cuda = autocast_func_cuda(mm)
938        mm_float16_cpu = autocast_func_cpu(mm)
939
940        def fn(a, b):
941            return mm_float16(a, b), mm_float16_cuda(a, b), mm_float16_cpu(a, b)
942
943        a_float32 = torch.rand((8, 8), device="cuda")
944        b_float32 = torch.rand((8, 8), device="cuda")
945
946        ref = fn(a_float32, b_float32)
947        opt_fn = torch.compile(backend="eager", fullgraph=True)(fn)
948        res = opt_fn(a_float32, b_float32)
949        self.assertTrue(same(ref, res))
950        self.assertTrue(res[0].dtype == torch.float16)
951        self.assertTrue(res[1].dtype == torch.float16)
952
953    def test_generic_ctx_manager_with_graph_break(self):
954        def fn(x):
955            with CustomizedCtxManagerWithGraphBreak(False):
956                # body runs on eager
957                y = x * 2
958                z = y.sin() + 3
959            return z
960
961        x = torch.randn(2, 3)
962        opt_fn = torch.compile(backend="eager", fullgraph=False)(fn)
963        self.assertEqual(fn(x), opt_fn(x))
964
965    def test_return_context_manager(self):
966        @torch.compile(backend="eager", fullgraph=True)
967        def f(x):
968            cm = CustomizedCtxManager(False)
969            with cm:
970                pass
971            return cm
972
973        x = torch.randn(2, 3)
974        cm = f(x)
975        self.assertFalse(cm.mode)
976
977    def test_return_context_manager_with_graph_break(self):
978        @torch.compile(backend="eager", fullgraph=False)
979        def f(x):
980            cm = CustomizedCtxManager(False)
981            torch._dynamo.graph_break()
982            with cm:
983                pass
984            return cm
985
986        x = torch.randn(2, 3)
987        cm = f(x)
988        self.assertFalse(cm.mode)
989
990    def test_generic_context_manager(self):
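        # A generic context manager with no graph breaks is inlined into a single graph;
        # changing the ambient grad mode invalidates the guard and forces a recompile.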
991        def fn(x):
992            with CustomizedCtxManager(True):
993                x = x + 1
994                if torch.is_grad_enabled():
995                    x = x * 2
996                x = torch.relu(x)
997            return x - 1
998
999        x = torch.rand(2, 3)
1000        cnts = torch._dynamo.testing.CompileCounter()
1001        opt_fn = torch.compile(backend=cnts, fullgraph=True)(fn)
1002
1003        with torch.no_grad():
1004            ref = fn(x)
1005            res = opt_fn(x)
1006            self.assertTrue(same(ref, res))
1007            self.assertEqual(cnts.frame_count, 1)
1008            self.assertEqual(cnts.op_count, 6)
1009
1010        with torch.enable_grad():
1011            ref = fn(x)
1012            res = opt_fn(x)
1013            self.assertTrue(same(ref, res))
1014            self.assertEqual(cnts.frame_count, 2)
1015            self.assertEqual(cnts.op_count, 12)
1016
1017    def test_nested_generic_context_manager(self):
1018        def fn(x):
1019            with CustomizedCtxManager(True):
1020                x = x + 1
1021                if torch.is_grad_enabled():
1022                    x = x * 2
1023                with CustomizedCtxManager(False):
1024                    if torch.is_grad_enabled():
1025                        x = x - 3
1026                    x = x * 1.5
1027                x = torch.relu(x)
1028            return x - 1
1029
1030        x = torch.rand(2, 3)
1031        cnts = torch._dynamo.testing.CompileCounter()
1032        opt_fn = torch.compile(backend=cnts, fullgraph=True)(fn)
1033
1034        with torch.no_grad():
1035            ref = fn(x)
1036            res = opt_fn(x)
1037            self.assertTrue(same(ref, res))
1038            self.assertEqual(cnts.frame_count, 1)
1039            self.assertEqual(cnts.op_count, 9)
1040
1041        with torch.enable_grad():
1042            ref = fn(x)
1043            res = opt_fn(x)
1044            self.assertTrue(same(ref, res))
1045            self.assertEqual(cnts.frame_count, 2)
1046            self.assertEqual(cnts.op_count, 18)
1047
1048    def test_generic_context_manager_with_graph_break(self):
1049        def fn(x):
1050            with CustomizedCtxManager(True):
1051                x = x + 1
1052                if torch.is_grad_enabled():
1053                    x = x * 2
1054                torch._dynamo.graph_break()
1055                x = torch.relu(x)
1056            return x - 1
1057
1058        x = torch.rand(2, 3)
1059        cnts = torch._dynamo.testing.CompileCounter()
1060        opt_fn = torch.compile(backend=cnts, fullgraph=False)(fn)
1061
1062        with torch.no_grad():
1063            ref = fn(x)
1064            res = opt_fn(x)
1065            self.assertTrue(same(ref, res))
1066            self.assertEqual(cnts.frame_count, 2)
1067            self.assertEqual(cnts.op_count, 2)
1068
1069        with torch.enable_grad():
1070            ref = fn(x)
1071            res = opt_fn(x)
1072            self.assertTrue(same(ref, res))
1073            self.assertEqual(cnts.frame_count, 4)
1074            self.assertEqual(cnts.op_count, 4)
1075
1076    def test_nested_generic_context_manager_with_graph_break(self):
1077        def fn(x):
1078            with CustomizedCtxManager(True):
1079                x = x + 1
1080                if torch.is_grad_enabled():
1081                    x = x * 2
1082                with CustomizedCtxManager(False):
1083                    if torch.is_grad_enabled():
1084                        x = x - 3
1085                    torch._dynamo.graph_break()
1086                    x = x * 1.5
1087                x = torch.relu(x)
1088            return x - 1
1089
1090        x = torch.rand(2, 3)
1091        cnts = torch._dynamo.testing.CompileCounter()
1092        opt_fn = torch.compile(backend=cnts, fullgraph=False)(fn)
1093
1094        with torch.no_grad():
1095            ref = fn(x)
1096            res = opt_fn(x)
1097            self.assertTrue(same(ref, res))
1098            self.assertEqual(cnts.frame_count, 4)
1099            self.assertEqual(cnts.op_count, 4)
1100
1101        torch._dynamo.reset()
1102        cnts = torch._dynamo.testing.CompileCounter()
1103        opt_fn = torch._dynamo.optimize(cnts, nopython=False)(fn)
1104
1105        with torch.enable_grad():
1106            ref = fn(x)
1107            res = opt_fn(x)
1108            self.assertTrue(same(ref, res))
1109            self.assertEqual(cnts.frame_count, 4)
1110            self.assertEqual(cnts.op_count, 4)
1111
1112    def test_graph_break_inlining_grad(self):
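        # A graph break inside an inlined no_grad region should not prevent backward()
        # from working on the outer graph's output.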
1113        def gn(z):
1114            with torch.no_grad():
1115                torch._dynamo.graph_break()
1116                return torch.sin(z)
1117
1118        def fn(x, y, z):
1119            a = torch.mm(x, y)
1120            z = gn(z)
1121            return a
1122
1123        torch._dynamo.reset()
1124        cnts = torch._dynamo.testing.CompileCounter()
1125        opt_fn = torch._dynamo.optimize(cnts, nopython=False)(fn)
1126        x = torch.randn(4, 4, requires_grad=True)
1127        y = torch.randn(4, 4, requires_grad=True)
1128        z = torch.randn(4)
1129        opt_fn(x, y, z).sum().backward()
1130
1131        self.assertEqual(cnts.frame_count, 2)
1132
1133    def _graph_break_inlining_autocast_test_helper(self, device):
1134        def gn(x, y):
1135            with torch.autocast(device_type=device, dtype=torch.bfloat16):
1136                z = torch.mm(x, y)
1137                torch._dynamo.graph_break()
1138                return torch.sin(z)
1139
1140        def fn(x, y):
1141            z = torch.mm(x, y)
1142            z = z + gn(x, y)
1143            return z
1144
1145        x = torch.rand(3, 3).to(device)
1146        y = torch.rand(3, 3).to(device)
1147        opt_fn = torch.compile(backend="eager")(fn)
1148        ref = fn(x, y)
1149        res = opt_fn(x, y)
1150        self.assertEqual(ref, res)
1151
1152    def test_graph_break_inlining_autocast(self):
1153        for device in ["cuda", "cpu"]:
1154            if device == "cuda" and not (
1155                torch.cuda.is_available() and torch.cuda.is_bf16_supported()
1156            ):
1157                continue
1158            self._graph_break_inlining_autocast_test_helper(device)
1159
1160    def test_disable_saved_tensors_hooks(self):
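        # disable_saved_tensors_hooks used as a decorator is traced into explicit
        # _saved_tensors_hooks_disable / _saved_tensors_hooks_enable calls in the graph.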
1161        def fn(z):
1162            @torch.autograd.graph.disable_saved_tensors_hooks("This is not supported")
1163            def f(x, y):
1164                return x + y
1165
1166            x, y = torch.ones(
1167                1,
1168            ), torch.zeros(
1169                1,
1170            )
1171            return f(x, y)
1172
1173        eager = EagerAndRecordGraphs()
1174        torch.compile(fn, backend=eager, fullgraph=True)(torch.randn(()))
1175
1176        graph = eager.graphs[0]
1177        actual = normalize_gm(graph.print_readable(False))
1178
1179        self.assertExpectedInline(
1180            actual,
1181            """\
1182class GraphModule(torch.nn.Module):
1183    def forward(self):
1184        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported');  _saved_tensors_hooks_disable = None
1185
1186        x: "f32[1]" = torch.ones(1)
1187
1188        y: "f32[1]" = torch.zeros(1)
1189
1190        add: "f32[1]" = x + y;  x = y = None
1191
1192        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable();  _saved_tensors_hooks_enable = None
1193        return (add,)
1194""",  # NOQA: B950
1195        )
1196
1197    def test_disable_saved_tensors_hooks_prev_disabled(self):
1198        def fn(z):
1199            @torch.autograd.graph.disable_saved_tensors_hooks("This is not supported")
1200            def f(x, y):
1201                return x + y
1202
1203            x, y = torch.ones(
1204                1,
1205            ), torch.zeros(
1206                1,
1207            )
1208            return f(x, y)
1209
1210        eager = EagerAndRecordGraphs()
1211        with torch.autograd.graph.disable_saved_tensors_hooks(
1212            "Previously disabled message"
1213        ):
1214            torch.compile(fn, backend=eager, fullgraph=True)(torch.randn(()))
1215
1216        graph = eager.graphs[0]
1217        actual = normalize_gm(graph.print_readable(False))
1218
1219        self.assertExpectedInline(
1220            actual,
1221            """\
1222class GraphModule(torch.nn.Module):
1223    def forward(self):
1224        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported');  _saved_tensors_hooks_disable = None
1225
1226        x: "f32[1]" = torch.ones(1)
1227
1228        y: "f32[1]" = torch.zeros(1)
1229
1230        add: "f32[1]" = x + y;  x = y = None
1231
1232        _saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable('Previously disabled message');  _saved_tensors_hooks_disable_1 = None
1233        return (add,)
1234""",  # NOQA: B950
1235        )
1236
1237    def test_disable_saved_tensors_hooks_prev_disabled_nested(self):
1238        def fn(z):
1239            @torch.autograd.graph.disable_saved_tensors_hooks("This is not supported")
1240            def f(x, y):
1241                @torch.autograd.graph.disable_saved_tensors_hooks(
1242                    "This is not supported inner"
1243                )
1244                def inner_fn(x, y):
1245                    return x + y
1246
1247                return inner_fn(x, y) + x
1248
1249            x, y = torch.ones(
1250                1,
1251            ), torch.zeros(
1252                1,
1253            )
1254            return f(x, y)
1255
1256        eager = EagerAndRecordGraphs()
1257        with torch.autograd.graph.disable_saved_tensors_hooks(
1258            "Previously disabled message"
1259        ):
1260            torch.compile(fn, backend=eager, fullgraph=True)(torch.randn(()))
1261
1262        graph = eager.graphs[0]
1263        actual = normalize_gm(graph.print_readable(False))
1264
1265        self.assertExpectedInline(
1266            actual,
1267            """\
1268class GraphModule(torch.nn.Module):
1269    def forward(self):
1270        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported');  _saved_tensors_hooks_disable = None
1271
1272        x: "f32[1]" = torch.ones(1)
1273
1274        y: "f32[1]" = torch.zeros(1)
1275
1276        _saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable('This is not supported inner');  _saved_tensors_hooks_disable_1 = None
1277
1278        add: "f32[1]" = x + y;  y = None
1279
1280        _saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable('This is not supported');  _saved_tensors_hooks_disable_2 = None
1281
1282        add_1: "f32[1]" = add + x;  add = x = None
1283
1284        _saved_tensors_hooks_disable_3 = torch._C._autograd._saved_tensors_hooks_disable('Previously disabled message');  _saved_tensors_hooks_disable_3 = None
1285        return (add_1,)
1286""",  # NOQA: B950
1287        )
1288
1289    def test_disable_saved_tensors_hooks_graph_break(self):
1290        def fn(x):
1291            with torch.autograd.graph.disable_saved_tensors_hooks(
1292                "This is not supported"
1293            ):
1294                y = x + 1
1295                torch._dynamo.graph_break()
1296                return y * 2
1297
1298        eager = EagerAndRecordGraphs()
1299        torch.compile(fn, backend=eager, fullgraph=False)(torch.randn(()))
1300
1301        def check_graph(actual, expected):
1302            self.assertExpectedInline(actual, expected)
1303
1304        graph = eager.graphs[0]
1305        actual = normalize_gm(graph.print_readable(False))
1306        self.assertExpectedInline(
1307            actual,
1308            """\
1309class GraphModule(torch.nn.Module):
1310    def forward(self, L_x_: "f32[]"):
1311        l_x_ = L_x_
1312
1313        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported');  _saved_tensors_hooks_disable = None
1314
1315        y: "f32[]" = l_x_ + 1;  l_x_ = None
1316
1317        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable();  _saved_tensors_hooks_enable = None
1318        return (y,)
1319""",  # NOQA: B950
1320        )
1321
1322        graph = eager.graphs[1]
1323        actual = normalize_gm(graph.print_readable(False))
1324        self.assertExpectedInline(
1325            actual,
1326            """\
1327class GraphModule(torch.nn.Module):
1328    def forward(self, L_y_: "f32[]"):
1329        l_y_ = L_y_
1330
1331        _saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported');  _saved_tensors_hooks_disable = None
1332
1333        mul: "f32[]" = l_y_ * 2;  l_y_ = None
1334
1335        _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable();  _saved_tensors_hooks_enable = None
1336        return (mul,)
1337""",  # NOQA: B950
1338        )
1339
1340    def test_context_wrapping_grad_mode_decorator(self):
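        # Applying enable_grad/no_grad to a function (as a decorator or by calling the
        # context manager on it) must not mutate the ambient grad mode at wrapping time.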
1341        ctx_wrappers = [(torch.enable_grad, True), (torch.no_grad, False)]
1342        for call in [True, False]:
1343            for i in range(2):
1344                torch._dynamo.reset()
1345
1346                ctx_wrapper, mode = ctx_wrappers[i]
1347                ctx_wrapper_inverse, mode_inverse = ctx_wrappers[(i + 1) % 2]
1348
1349                def fn(x):
1350                    def inner_func(x):
1351                        return x.sin()
1352
1353                    with ctx_wrapper_inverse():
1354                        if call:
1355                            inner_func = ctx_wrapper()(inner_func)
1356                        else:
1357                            inner_func = ctx_wrapper(inner_func)
1358
1359                        # Calling no_grad or enable_grad should not mutate global state
1360                        assert torch.is_grad_enabled() == mode_inverse
1361
1362                    with ctx_wrapper_inverse():
1363                        return inner_func(x)
1364
1365                x = torch.zeros(10, requires_grad=True)
1366                opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
1367                self.assertEqual(fn(x), opt_fn(x))
1368                self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
1369
1370    def test_context_wrapping_grad_mode_nested_function_decorator(self):
1371        ctx_wrappers = [(torch.enable_grad, True), (torch.no_grad, False)]
1372
1373        for call in [True, False]:
1374            for i in range(2):
1375                torch._dynamo.reset()
1376
1377                ctx_wrapper, mode = ctx_wrappers[i]
1378                ctx_wrapper_inverse, mode_inverse = ctx_wrappers[(i + 1) % 2]
1379
1380                def fn(x):
1381                    with ctx_wrapper_inverse():
1382                        if call:
1383
1384                            @ctx_wrapper()
1385                            def inner_func(x):
1386                                return x.sin()
1387
1388                        else:
1389
1390                            @ctx_wrapper
1391                            def inner_func(x):
1392                                return x.sin()
1393
1394                        # Calling no_grad or enable_grad should not mutate global state
1395                        assert torch.is_grad_enabled() == mode_inverse
1396
1397                    with ctx_wrapper_inverse():
1398                        return inner_func(x)
1399
1400                x = torch.zeros(10, requires_grad=True)
1401                opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
1402                self.assertEqual(fn(x), opt_fn(x))
1403                self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
1404
1405    def test_context_wrapping_set_grad_enabled_nested_function(self):
1406        modes = [True, False]
1407        for decorator in [True, False]:
1408            for i in range(2):
1409                torch._dynamo.reset()
1410
1411                mode = modes[i]
1412                mode_inverse = modes[(i + 1) % 2]
1413
1414                def fn(x):
1415                    with torch.set_grad_enabled(mode_inverse):
1416                        if decorator:
1417
1418                            @torch.set_grad_enabled(mode)
1419                            def inner_func(x):
1420                                return x.sin()
1421
1422                        else:
1423
1424                            def inner_func(x):
1425                                return x.sin()
1426
1427                            inner_func = torch.set_grad_enabled(mode)(inner_func)
1428
1429                        # Consuming set_grad_enabled by calling it on a function
1430                        # should not mutate global state
1431                        assert torch.is_grad_enabled() == mode_inverse
1432
1433                    with torch.set_grad_enabled(mode_inverse):
1434                        return inner_func(x)
1435
1436            x = torch.zeros(10, requires_grad=True)
1437            opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
1438            self.assertEqual(fn(x), opt_fn(x))
1439            self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
1440
1441    def test_inactive_context_graph_break_local(self):
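        # A context manager object created before a graph break must survive the break and
        # still work when entered afterwards.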
1442        def fn(x):
1443            x = x + 1
1444            ctx = torch.set_grad_enabled(True)
1445            torch._dynamo.graph_break()
1446            with ctx:
1447                x = x + 1
1448            return x
1449
1450        x = torch.zeros(10, requires_grad=False)
1451        cnts = torch._dynamo.testing.CompileCounter()
1452        opt_fn = torch.compile(fn, backend=cnts)
1453        self.assertEqual(fn(x), opt_fn(x))
1454        self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
1455        self.assertEqual(cnts.frame_count, 2)
1456
1457    def test_inactive_context_graph_break_local_nullctx(self):
1458        import contextlib
1459
1460        # Test with a context manager that results in None target_values
1461        def fn(x):
1462            x = x + 1
1463            ctx = contextlib.nullcontext()
1464            torch._dynamo.graph_break()
1465            with ctx:
1466                x = x + 1
1467            return x
1468
1469        x = torch.zeros(10, requires_grad=False)
1470        cnts = torch._dynamo.testing.CompileCounter()
1471        opt_fn = torch.compile(fn, backend=cnts)
1472        self.assertEqual(fn(x), opt_fn(x))
1473        self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
1474        self.assertEqual(cnts.frame_count, 2)
1475
1476    def test_inactive_context_graph_break_local_nullctx2(self):
1477        import contextlib
1478
1479        # Test with nullcontext where a graph break happens
1480        # in an inlined function that returns something
1481        def gn():
1482            torch._dynamo.graph_break()
1483            return [0, 1, 2]
1484
1485        def fn(x):
1486            x = x + 1
1487            ctx = contextlib.nullcontext()
1488            lst = gn()
1489            with ctx:
1490                x = x + lst[1]
1491            return x
1492
1493        x = torch.zeros(10, requires_grad=False)
1494        cnts = torch._dynamo.testing.CompileCounter()
1495        opt_fn = torch.compile(fn, backend=cnts)
1496        self.assertEqual(fn(x), opt_fn(x))
1497        self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
1498        self.assertEqual(cnts.frame_count, 2)
1499
1500    def test_inactive_context_graph_break_stack(self):
1501        def gn(ctx):
1502            torch._dynamo.graph_break()
1503            return ctx
1504
1505        def fn(x):
1506            x = x + 1
1507            ctx = gn(torch.set_grad_enabled(True))
1508            # we expect a graph break on the next line as well
1509            with ctx:
1510                x = x + 1
1511            return x
1512
1513        x = torch.zeros(10, requires_grad=False)
1514        cnts = torch._dynamo.testing.CompileCounter()
1515        opt_fn = torch.compile(fn, backend=cnts)
1516        self.assertEqual(fn(x), opt_fn(x))
1517        self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
1518
1519    def test_inactive_context_graph_break_stack2(self):
1520        def gn(x, ctx, y, z, dummy):
1521            with ctx:
1522                return x * y * z
1523
1524        def fn(x):
1525            x = x + 1
1526            x = gn(x, torch.set_grad_enabled(True), 2, 3, torch._dynamo.graph_break())
1527            return x
1528
1529        x = torch.zeros(10, requires_grad=False)
1530        cnts = torch._dynamo.testing.CompileCounter()
1531        opt_fn = torch.compile(fn, backend=cnts)
1532        self.assertEqual(fn(x), opt_fn(x))
1533        self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
1534        self.assertEqual(cnts.frame_count, 2)
1535
1536
1537if __name__ == "__main__":
1538    from torch._dynamo.test_case import run_tests
1539
1540    run_tests()
1541