1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: cuda graphs"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport functools 4*da0073e9SAndroid Build Coastguard Workerimport unittest 5*da0073e9SAndroid Build Coastguard Worker 6*da0073e9SAndroid Build Coastguard Workerimport torch 7*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo 8*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.config 9*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case 10*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing 11*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.testing import same 12*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import TEST_CUDA_GRAPH 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard Worker 15*da0073e9SAndroid Build Coastguard Workerdef composed(*decs): 16*da0073e9SAndroid Build Coastguard Worker def deco(f): 17*da0073e9SAndroid Build Coastguard Worker for dec in reversed(decs): 18*da0073e9SAndroid Build Coastguard Worker f = dec(f) 19*da0073e9SAndroid Build Coastguard Worker return f 20*da0073e9SAndroid Build Coastguard Worker 21*da0073e9SAndroid Build Coastguard Worker return deco 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Worker 24*da0073e9SAndroid Build Coastguard Workerdef assert_aot_autograd_counter(ok=True): 25*da0073e9SAndroid Build Coastguard Worker def deco(f): 26*da0073e9SAndroid Build Coastguard Worker @functools.wraps(f) 27*da0073e9SAndroid Build Coastguard Worker def wrap(self, *args, **kwargs): 28*da0073e9SAndroid Build Coastguard Worker torch._dynamo.utils.counters.clear() 29*da0073e9SAndroid Build Coastguard Worker r = f(self, *args, **kwargs) 30*da0073e9SAndroid Build Coastguard Worker c_ok = torch._dynamo.utils.counters["aot_autograd"]["ok"] 31*da0073e9SAndroid Build Coastguard Worker c_not_ok = torch._dynamo.utils.counters["aot_autograd"]["not_ok"] 32*da0073e9SAndroid Build Coastguard Worker if ok: 33*da0073e9SAndroid Build Coastguard Worker self.assertGreater(c_ok, 0) 34*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c_not_ok, 0) 35*da0073e9SAndroid Build Coastguard Worker else: 36*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c_ok, 0) 37*da0073e9SAndroid Build Coastguard Worker self.assertGreater(c_not_ok, 0) 38*da0073e9SAndroid Build Coastguard Worker return r 39*da0073e9SAndroid Build Coastguard Worker 40*da0073e9SAndroid Build Coastguard Worker return wrap 41*da0073e9SAndroid Build Coastguard Worker 42*da0073e9SAndroid Build Coastguard Worker return deco 43*da0073e9SAndroid Build Coastguard Worker 44*da0073e9SAndroid Build Coastguard Worker 45*da0073e9SAndroid Build Coastguard Workerdef patch_all(ok=True): 46*da0073e9SAndroid Build Coastguard Worker return composed( 47*da0073e9SAndroid Build Coastguard Worker torch._dynamo.config.patch( 48*da0073e9SAndroid Build Coastguard Worker verify_correctness=True, automatic_dynamic_shapes=True 49*da0073e9SAndroid Build Coastguard Worker ), 50*da0073e9SAndroid Build Coastguard Worker assert_aot_autograd_counter(ok), 51*da0073e9SAndroid Build Coastguard Worker ) 52*da0073e9SAndroid Build Coastguard Worker 53*da0073e9SAndroid Build Coastguard Worker 54*da0073e9SAndroid Build Coastguard WorkerN_ITERS = 5 55*da0073e9SAndroid Build Coastguard Worker 56*da0073e9SAndroid Build Coastguard Worker 57*da0073e9SAndroid Build Coastguard Worker@unittest.skipIf(not torch.cuda.is_available(), "these tests require cuda") 58*da0073e9SAndroid Build Coastguard Workerclass TestAotCudagraphs(torch._dynamo.test_case.TestCase): 59*da0073e9SAndroid Build Coastguard Worker @patch_all() 60*da0073e9SAndroid Build Coastguard Worker def test_basic(self): 61*da0073e9SAndroid Build Coastguard Worker def model(x, y): 62*da0073e9SAndroid Build Coastguard Worker return (x + y) * y 63*da0073e9SAndroid Build Coastguard Worker 64*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize("cudagraphs") 65*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 66*da0073e9SAndroid Build Coastguard Worker for i in range(N_ITERS): 67*da0073e9SAndroid Build Coastguard Worker loss = model(x, y).sum() 68*da0073e9SAndroid Build Coastguard Worker loss.backward() 69*da0073e9SAndroid Build Coastguard Worker 70*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, device="cuda", requires_grad=True) 71*da0073e9SAndroid Build Coastguard Worker y = torch.randn(3, device="cuda") 72*da0073e9SAndroid Build Coastguard Worker fn(x, y) 73*da0073e9SAndroid Build Coastguard Worker 74*da0073e9SAndroid Build Coastguard Worker @patch_all() 75*da0073e9SAndroid Build Coastguard Worker def test_dtoh(self): 76*da0073e9SAndroid Build Coastguard Worker def model(x, y): 77*da0073e9SAndroid Build Coastguard Worker a = x + y 78*da0073e9SAndroid Build Coastguard Worker b = a.cpu() * 3 79*da0073e9SAndroid Build Coastguard Worker return b 80*da0073e9SAndroid Build Coastguard Worker 81*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize("cudagraphs") 82*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 83*da0073e9SAndroid Build Coastguard Worker for i in range(N_ITERS): 84*da0073e9SAndroid Build Coastguard Worker loss = model(x, y).sum() 85*da0073e9SAndroid Build Coastguard Worker loss.backward() 86*da0073e9SAndroid Build Coastguard Worker 87*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, device="cuda", requires_grad=True) 88*da0073e9SAndroid Build Coastguard Worker y = torch.randn(3, device="cuda") 89*da0073e9SAndroid Build Coastguard Worker fn(x, y) 90*da0073e9SAndroid Build Coastguard Worker 91*da0073e9SAndroid Build Coastguard Worker @patch_all() 92*da0073e9SAndroid Build Coastguard Worker def test_htod(self): 93*da0073e9SAndroid Build Coastguard Worker def model(x, y): 94*da0073e9SAndroid Build Coastguard Worker a = x + y 95*da0073e9SAndroid Build Coastguard Worker return a * 3 96*da0073e9SAndroid Build Coastguard Worker 97*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize("cudagraphs") 98*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 99*da0073e9SAndroid Build Coastguard Worker for i in range(N_ITERS): 100*da0073e9SAndroid Build Coastguard Worker loss = model(x, y).sum() 101*da0073e9SAndroid Build Coastguard Worker loss.backward() 102*da0073e9SAndroid Build Coastguard Worker 103*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, device="cuda", requires_grad=True) 104*da0073e9SAndroid Build Coastguard Worker y = torch.randn((), device="cpu") 105*da0073e9SAndroid Build Coastguard Worker fn(x, y) 106*da0073e9SAndroid Build Coastguard Worker 107*da0073e9SAndroid Build Coastguard Worker def test_mutate_input(self): 108*da0073e9SAndroid Build Coastguard Worker def model(x, y): 109*da0073e9SAndroid Build Coastguard Worker y.add_(3) 110*da0073e9SAndroid Build Coastguard Worker return x * y 111*da0073e9SAndroid Build Coastguard Worker 112*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize("cudagraphs") 113*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 114*da0073e9SAndroid Build Coastguard Worker for i in range(N_ITERS): 115*da0073e9SAndroid Build Coastguard Worker with self.subTest(i): 116*da0073e9SAndroid Build Coastguard Worker y_orig = y.clone() 117*da0073e9SAndroid Build Coastguard Worker loss = model(x, y).sum() 118*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(y, y_orig + 3)) 119*da0073e9SAndroid Build Coastguard Worker loss.backward() 120*da0073e9SAndroid Build Coastguard Worker 121*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, device="cuda", requires_grad=True) 122*da0073e9SAndroid Build Coastguard Worker y = torch.randn(3, device="cuda") 123*da0073e9SAndroid Build Coastguard Worker fn(x, y) 124*da0073e9SAndroid Build Coastguard Worker 125*da0073e9SAndroid Build Coastguard Worker @patch_all() 126*da0073e9SAndroid Build Coastguard Worker def test_mutate_constant(self): 127*da0073e9SAndroid Build Coastguard Worker def model(x, y): 128*da0073e9SAndroid Build Coastguard Worker c = torch.tensor(1) 129*da0073e9SAndroid Build Coastguard Worker c.add_(2) 130*da0073e9SAndroid Build Coastguard Worker return x * y * 0 + c 131*da0073e9SAndroid Build Coastguard Worker 132*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize("cudagraphs") 133*da0073e9SAndroid Build Coastguard Worker def fn(x, y): 134*da0073e9SAndroid Build Coastguard Worker for i in range(N_ITERS): 135*da0073e9SAndroid Build Coastguard Worker with self.subTest(i): 136*da0073e9SAndroid Build Coastguard Worker loss = model(x, y).sum() 137*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(loss, torch.tensor(3.0, device="cuda"))) 138*da0073e9SAndroid Build Coastguard Worker loss.backward() 139*da0073e9SAndroid Build Coastguard Worker 140*da0073e9SAndroid Build Coastguard Worker x = torch.randn(1, device="cuda", requires_grad=True) 141*da0073e9SAndroid Build Coastguard Worker y = torch.randn(1, device="cuda") 142*da0073e9SAndroid Build Coastguard Worker fn(x, y) 143*da0073e9SAndroid Build Coastguard Worker 144*da0073e9SAndroid Build Coastguard Worker @patch_all() 145*da0073e9SAndroid Build Coastguard Worker def test_factory(self): 146*da0073e9SAndroid Build Coastguard Worker def model(y): 147*da0073e9SAndroid Build Coastguard Worker x = torch.zeros(3, device="cuda:0") 148*da0073e9SAndroid Build Coastguard Worker x.add_(3) 149*da0073e9SAndroid Build Coastguard Worker return x * y 150*da0073e9SAndroid Build Coastguard Worker 151*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize("cudagraphs") 152*da0073e9SAndroid Build Coastguard Worker def fn(y): 153*da0073e9SAndroid Build Coastguard Worker for i in range(N_ITERS): 154*da0073e9SAndroid Build Coastguard Worker with self.subTest(i): 155*da0073e9SAndroid Build Coastguard Worker loss = model(y).sum() 156*da0073e9SAndroid Build Coastguard Worker loss.backward() 157*da0073e9SAndroid Build Coastguard Worker 158*da0073e9SAndroid Build Coastguard Worker y = torch.randn(3, device="cuda:0", requires_grad=True) 159*da0073e9SAndroid Build Coastguard Worker fn(y) 160*da0073e9SAndroid Build Coastguard Worker 161*da0073e9SAndroid Build Coastguard Worker @patch_all() 162*da0073e9SAndroid Build Coastguard Worker def test_mutated_metadata(self): 163*da0073e9SAndroid Build Coastguard Worker # more tortured example at 164*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/81385 165*da0073e9SAndroid Build Coastguard Worker def model(x): 166*da0073e9SAndroid Build Coastguard Worker x = x.clone() 167*da0073e9SAndroid Build Coastguard Worker x.resize_(20) 168*da0073e9SAndroid Build Coastguard Worker x.fill_(2) 169*da0073e9SAndroid Build Coastguard Worker return x 170*da0073e9SAndroid Build Coastguard Worker 171*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize("cudagraphs") 172*da0073e9SAndroid Build Coastguard Worker def fn(x): 173*da0073e9SAndroid Build Coastguard Worker for i in range(N_ITERS): 174*da0073e9SAndroid Build Coastguard Worker with self.subTest(i): 175*da0073e9SAndroid Build Coastguard Worker rx = model(x) 176*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(rx, torch.full((20,), 2.0, device="cuda:0"))) 177*da0073e9SAndroid Build Coastguard Worker 178*da0073e9SAndroid Build Coastguard Worker x = torch.empty(0, device="cuda:0") 179*da0073e9SAndroid Build Coastguard Worker fn(x) 180*da0073e9SAndroid Build Coastguard Worker 181*da0073e9SAndroid Build Coastguard Worker @patch_all() 182*da0073e9SAndroid Build Coastguard Worker def test_dead_fill(self): 183*da0073e9SAndroid Build Coastguard Worker def model(x): 184*da0073e9SAndroid Build Coastguard Worker x = x.clone() 185*da0073e9SAndroid Build Coastguard Worker y = x[0:0] 186*da0073e9SAndroid Build Coastguard Worker x.fill_(2) 187*da0073e9SAndroid Build Coastguard Worker y.fill_(3) 188*da0073e9SAndroid Build Coastguard Worker return x, y 189*da0073e9SAndroid Build Coastguard Worker 190*da0073e9SAndroid Build Coastguard Worker @torch._dynamo.optimize("cudagraphs") 191*da0073e9SAndroid Build Coastguard Worker def fn(x): 192*da0073e9SAndroid Build Coastguard Worker for i in range(N_ITERS): 193*da0073e9SAndroid Build Coastguard Worker with self.subTest(i): 194*da0073e9SAndroid Build Coastguard Worker rx, ry = model(x) 195*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(rx, torch.full((20,), 2.0, device="cuda:0"))) 196*da0073e9SAndroid Build Coastguard Worker self.assertTrue(same(ry, torch.empty(0, device="cuda:0"))) 197*da0073e9SAndroid Build Coastguard Worker 198*da0073e9SAndroid Build Coastguard Worker x = torch.empty(20, device="cuda:0") 199*da0073e9SAndroid Build Coastguard Worker fn(x) 200*da0073e9SAndroid Build Coastguard Worker 201*da0073e9SAndroid Build Coastguard Worker 202*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 203*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.test_case import run_tests 204*da0073e9SAndroid Build Coastguard Worker 205*da0073e9SAndroid Build Coastguard Worker if not TEST_CUDA_GRAPH: 206*da0073e9SAndroid Build Coastguard Worker if __name__ == "__main__": 207*da0073e9SAndroid Build Coastguard Worker import sys 208*da0073e9SAndroid Build Coastguard Worker 209*da0073e9SAndroid Build Coastguard Worker sys.exit(0) 210*da0073e9SAndroid Build Coastguard Worker raise unittest.SkipTest("cuda graph test is skipped") 211*da0073e9SAndroid Build Coastguard Worker 212*da0073e9SAndroid Build Coastguard Worker run_tests() 213