1# Owner(s): ["module: dynamo"] 2import contextlib 3 4import torch._dynamo.test_case 5import torch._dynamo.testing 6from torch._dynamo.testing import CompileCounter 7from torch.backends.cuda import SDPAParams 8 9 10@contextlib.contextmanager 11def allow_in_graph_sdpa_params(): 12 global SDPAParams 13 try: 14 old = SDPAParams 15 SDPAParams = torch._dynamo.allow_in_graph(SDPAParams) 16 yield 17 finally: 18 SDPAParams = old 19 20 21class TestSDPA(torch._dynamo.test_case.TestCase): 22 def assert_ref_equals_params(self, actual, expected): 23 self.assertIs(actual.query, expected.query) 24 self.assertIs(actual.key, expected.key) 25 self.assertIs(actual.value, expected.value) 26 self.assertIs(actual.attn_mask, expected.attn_mask) 27 28 def test_returns_SDPAParams(self): 29 with allow_in_graph_sdpa_params(): 30 counter = CompileCounter() 31 32 @torch.compile(fullgraph=True, backend=counter) 33 def fn(q, k, v, m): 34 return SDPAParams(q, k, v, m, 0.1, True, False) 35 36 q = torch.randn(10) 37 k = torch.randn(10) 38 v = torch.randn(10) 39 m = torch.randn(10) 40 o = fn(q, k, v, m) 41 self.assertTrue(isinstance(o, SDPAParams)) 42 self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True, False)) 43 self.assertEqual(counter.frame_count, 1) 44 45 def test_graph_break_SDPAParams(self): 46 with allow_in_graph_sdpa_params(): 47 counter = CompileCounter() 48 49 @torch.compile(backend=counter) 50 def fn(q, k, v, m): 51 z = SDPAParams(q, k, v, m, 0.1, True, False) 52 torch._dynamo.graph_break() 53 return z, q + 1 54 55 q = torch.randn(10) 56 k = torch.randn(10) 57 v = torch.randn(10) 58 m = torch.randn(10) 59 o, _ = fn(q, k, v, m) 60 self.assertTrue(isinstance(o, SDPAParams)) 61 self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True, False)) 62 self.assertEqual(counter.frame_count, 2) 63 64 def test_input_SDPAParams(self): 65 with allow_in_graph_sdpa_params(): 66 counter = CompileCounter() 67 68 @torch.compile(backend=counter) 69 def fn(sdpap, q): 70 torch._dynamo.graph_break() 71 return sdpap, sdpap.query + q 72 73 q = torch.randn(10) 74 k = torch.randn(10) 75 v = torch.randn(10) 76 m = torch.randn(10) 77 s = SDPAParams(q, k, v, m, 0.1, True, False) 78 o, _ = fn(s, q) 79 self.assertIs(o, s) 80 self.assertEqual(counter.frame_count, 1) 81 82 def test_intermediate_attr_access_SDPAParams(self): 83 with allow_in_graph_sdpa_params(): 84 counter = CompileCounter() 85 86 @torch.compile(fullgraph=True, backend=counter) 87 def fn(q, k, v, m): 88 q += 1 89 z = SDPAParams(q, k, v, m, 0.1, True, False) 90 a = z.query 91 return a + 1, z, q 92 93 q = torch.randn(10) 94 k = torch.randn(10) 95 v = torch.randn(10) 96 m = torch.randn(10) 97 _, o, _ = fn(q, k, v, m) 98 expected = SDPAParams(q, k, v, m, 0.1, True, False) 99 self.assert_ref_equals_params(o, expected) 100 self.assertEqual(counter.frame_count, 1) 101 102 103if __name__ == "__main__": 104 from torch._dynamo.test_case import run_tests 105 106 run_tests() 107