xref: /aosp_15_r20/external/pytorch/test/dynamo/test_cudagraphs.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: cuda graphs"]
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport functools
4*da0073e9SAndroid Build Coastguard Workerimport unittest
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport torch
7*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo
8*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.config
9*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case
10*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing
11*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.testing import same
12*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import TEST_CUDA_GRAPH
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Worker
def composed(*decs):
    """Fold several decorators into a single decorator.

    ``composed(a, b)(f)`` produces ``a(b(f))``, i.e. the same result as
    stacking ``@a`` above ``@b`` on ``f``: the last decorator listed is
    applied first and the first one listed ends up outermost. With no
    decorators, the function is returned unchanged.
    """

    def apply_all(fn):
        # Walk the decorators from innermost (last) to outermost (first).
        return functools.reduce(lambda wrapped, dec: dec(wrapped), reversed(decs), fn)

    return apply_all
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard Worker
def assert_aot_autograd_counter(ok=True):
    """Build a decorator for test methods that checks AOTAutograd's counters.

    The wrapped test runs after ``torch._dynamo.utils.counters`` has been
    cleared. Afterwards the ``"ok"`` and ``"not_ok"`` entries of the
    ``"aot_autograd"`` counter group are checked:

    * ``ok=True``  -> at least one success and no failures are required;
    * ``ok=False`` -> no successes and at least one failure are required.

    The wrapped test's return value is passed through unchanged.
    """

    def decorator(test_fn):
        @functools.wraps(test_fn)
        def checked(self, *args, **kwargs):
            # Start from a clean slate so only this test's compiles are counted.
            torch._dynamo.utils.counters.clear()
            outcome = test_fn(self, *args, **kwargs)

            aot_group = torch._dynamo.utils.counters["aot_autograd"]
            succeeded = aot_group["ok"]
            failed = aot_group["not_ok"]

            if not ok:
                self.assertEqual(succeeded, 0)
                self.assertGreater(failed, 0)
            else:
                self.assertGreater(succeeded, 0)
                self.assertEqual(failed, 0)
            return outcome

        return checked

    return decorator
43*da0073e9SAndroid Build Coastguard Worker
44*da0073e9SAndroid Build Coastguard Worker
def patch_all(ok=True):
    """Return the standard decorator stack used by the tests in this file.

    The stack combines:

    * a ``torch._dynamo.config.patch`` that sets ``verify_correctness=True``
      and ``automatic_dynamic_shapes=True``;
    * ``assert_aot_autograd_counter(ok)``, which checks the AOTAutograd
      ok/not_ok counters after the test body runs.

    Because ``composed`` applies its arguments right-to-left, the config
    patch is the outer layer and the counter check runs inside it.
    """
    config_patch = torch._dynamo.config.patch(
        verify_correctness=True,
        automatic_dynamic_shapes=True,
    )
    counter_check = assert_aot_autograd_counter(ok)
    return composed(config_patch, counter_check)
52*da0073e9SAndroid Build Coastguard Worker
53*da0073e9SAndroid Build Coastguard Worker
# How many times each compiled driver function below runs its model.
N_ITERS: int = 5
55*da0073e9SAndroid Build Coastguard Worker
56*da0073e9SAndroid Build Coastguard Worker
@unittest.skipIf(not torch.cuda.is_available(), "these tests require cuda")
class TestAotCudagraphs(torch._dynamo.test_case.TestCase):
    """Tests for the ``"cudagraphs"`` Dynamo backend.

    Each test defines a small ``model``. It then wraps a driver loop in
    ``torch._dynamo.optimize("cudagraphs")`` that calls the model
    ``N_ITERS`` times, and invokes the driver on CUDA tensors. Most tests are
    decorated with ``patch_all()``. That decorator enables
    ``verify_correctness`` / ``automatic_dynamic_shapes`` and asserts that
    AOTAutograd reported at least one success and no failures.
    """

    @patch_all()
    def test_basic(self):
        # Pure-CUDA elementwise forward followed by backward.
        def model(x, y):
            return (x + y) * y

        @torch._dynamo.optimize("cudagraphs")
        def fn(x, y):
            for _ in range(N_ITERS):
                loss = model(x, y).sum()
                loss.backward()

        x = torch.randn(3, device="cuda", requires_grad=True)
        y = torch.randn(3, device="cuda")
        fn(x, y)

    @patch_all()
    def test_dtoh(self):
        # The graph copies an intermediate to the CPU (``a.cpu()``) and keeps
        # computing there.
        def model(x, y):
            a = x + y
            b = a.cpu() * 3
            return b

        @torch._dynamo.optimize("cudagraphs")
        def fn(x, y):
            for _ in range(N_ITERS):
                loss = model(x, y).sum()
                loss.backward()

        x = torch.randn(3, device="cuda", requires_grad=True)
        y = torch.randn(3, device="cuda")
        fn(x, y)

    @patch_all()
    def test_htod(self):
        # ``y`` is a 0-dim CPU tensor combined with the CUDA tensor ``x``.
        def model(x, y):
            a = x + y
            return a * 3

        @torch._dynamo.optimize("cudagraphs")
        def fn(x, y):
            for _ in range(N_ITERS):
                loss = model(x, y).sum()
                loss.backward()

        x = torch.randn(3, device="cuda", requires_grad=True)
        y = torch.randn((), device="cpu")
        fn(x, y)

    # NOTE(review): unlike its siblings this test is not wrapped in
    # ``patch_all()``. Presumably this is because ``verify_correctness`` would
    # run the in-place ``y.add_(3)`` an extra time and break the
    # ``y_orig + 3`` check below. Confirm before adding the decorator.
    def test_mutate_input(self):
        # The model mutates one of its graph inputs in place; the mutation
        # must be visible to the caller after every call.
        def model(x, y):
            y.add_(3)
            return x * y

        @torch._dynamo.optimize("cudagraphs")
        def fn(x, y):
            for i in range(N_ITERS):
                with self.subTest(i):
                    y_orig = y.clone()
                    loss = model(x, y).sum()
                    self.assertTrue(same(y, y_orig + 3))
                    loss.backward()

        x = torch.randn(3, device="cuda", requires_grad=True)
        y = torch.randn(3, device="cuda")
        fn(x, y)

    @patch_all()
    def test_mutate_constant(self):
        # A CPU tensor created inside the graph is mutated in place; the
        # result is ``0 + (1 + 2)`` for the single-element inputs below.
        def model(x, y):
            c = torch.tensor(1)
            c.add_(2)
            return x * y * 0 + c

        @torch._dynamo.optimize("cudagraphs")
        def fn(x, y):
            for i in range(N_ITERS):
                with self.subTest(i):
                    loss = model(x, y).sum()
                    self.assertTrue(same(loss, torch.tensor(3.0, device="cuda")))
                    loss.backward()

        x = torch.randn(1, device="cuda", requires_grad=True)
        y = torch.randn(1, device="cuda")
        fn(x, y)

    @patch_all()
    def test_factory(self):
        # A CUDA tensor is created by a factory call inside the graph and then
        # mutated in place.
        def model(y):
            x = torch.zeros(3, device="cuda:0")
            x.add_(3)
            return x * y

        @torch._dynamo.optimize("cudagraphs")
        def fn(y):
            for i in range(N_ITERS):
                with self.subTest(i):
                    loss = model(y).sum()
                    loss.backward()

        y = torch.randn(3, device="cuda:0", requires_grad=True)
        fn(y)

    @patch_all()
    def test_mutated_metadata(self):
        # ``resize_`` changes tensor metadata (size 0 -> 20) inside the graph.
        # more tortured example at
        # https://github.com/pytorch/pytorch/issues/81385
        def model(x):
            x = x.clone()
            x.resize_(20)
            x.fill_(2)
            return x

        @torch._dynamo.optimize("cudagraphs")
        def fn(x):
            for i in range(N_ITERS):
                with self.subTest(i):
                    rx = model(x)
                    self.assertTrue(same(rx, torch.full((20,), 2.0, device="cuda:0")))

        x = torch.empty(0, device="cuda:0")
        fn(x)

    @patch_all()
    def test_dead_fill(self):
        # ``y`` is an empty view of ``x``, so ``y.fill_(3)`` writes nothing and
        # ``x`` must still be all 2s.
        def model(x):
            x = x.clone()
            y = x[0:0]
            x.fill_(2)
            y.fill_(3)
            return x, y

        @torch._dynamo.optimize("cudagraphs")
        def fn(x):
            for i in range(N_ITERS):
                with self.subTest(i):
                    rx, ry = model(x)
                    self.assertTrue(same(rx, torch.full((20,), 2.0, device="cuda:0")))
                    self.assertTrue(same(ry, torch.empty(0, device="cuda:0")))

        x = torch.empty(20, device="cuda:0")
        fn(x)
201*da0073e9SAndroid Build Coastguard Worker
# When TEST_CUDA_GRAPH is false, skip this whole module:
# - When run as a script, exit successfully before any test runs.
# - When imported by a test runner, raise SkipTest so the module is reported
#   as skipped.
# This guard must sit at module level. Nested inside the ``__main__`` check,
# the inner ``__name__`` test is always true and the SkipTest branch can
# never be reached.
if not TEST_CUDA_GRAPH:
    if __name__ == "__main__":
        import sys

        sys.exit(0)
    raise unittest.SkipTest("cuda graph test is skipped")

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
213