# xref: /aosp_15_r20/external/pytorch/test/dynamo/test_pre_dispatch.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: dynamo"]
import torch
import torch._dynamo
import torch._dynamo.test_case
5
6
7class PreDispatchTests(torch._dynamo.test_case.TestCase):
8    def test_no_grad_simple(self):
9        def f(a):
10            b = a.sin()
11            with torch.no_grad():
12                c = b.cos()
13            return b * c.sin()
14
15        f_compiled = torch.compile(f, backend="pre_dispatch_eager")
16
17        a_ref = torch.randn(4, requires_grad=True)
18        a_test = a_ref.clone().detach().requires_grad_(True)
19
20        out_ref = f(a_ref)
21        out_test = f_compiled(a_test)
22        self.assertEqual(out_ref, out_test)
23
24        out_ref.sum().backward()
25        out_test.sum().backward()
26        self.assertEqual(a_ref.grad, a_test.grad)
27
28    def test_enable_grad_and_no_grad(self):
29        def f(a):
30            b = a * 2
31            with torch.no_grad():
32                c = b * 3
33                with torch.enable_grad():
34                    d = c * 4
35                e = d * 5
36            return b + c + d + e
37
38        f_compiled = torch.compile(f, backend="pre_dispatch_eager")
39
40        a_ref = torch.randn(4, requires_grad=True)
41        a_test = a_ref.clone().detach().requires_grad_(True)
42
43        out_ref = f(a_ref)
44        out_test = f_compiled(a_test)
45        self.assertEqual(out_ref, out_test)
46
47        out_ref.sum().backward()
48        out_test.sum().backward()
49        self.assertEqual(a_ref.grad, a_test.grad)
50
51    def test_autocast_simple(self):
52        def f(a):
53            b = a * 2
54            with torch.amp.autocast(device_type="cpu"):
55                c = torch.matmul(b, b)
56            return b + c
57
58        f_compiled = torch.compile(f, backend="pre_dispatch_eager")
59
60        a_ref = torch.randn(4, device="cpu", requires_grad=True)
61        a_test = a_ref.clone().detach().requires_grad_(True)
62
63        out_ref = f(a_ref)
64        out_test = f_compiled(a_test)
65        self.assertEqual(out_ref, out_test)
66
67        out_ref.sum().backward()
68        out_test.sum().backward()
69        self.assertEqual(a_ref.grad, a_test.grad)
70
71
if __name__ == "__main__":
    import torch._dynamo.test_case

    # Dispatch to dynamo's test runner so dynamo-aware setup/teardown applies.
    torch._dynamo.test_case.run_tests()
76