1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: dynamo"] 2*da0073e9SAndroid Build Coastguard Workerimport contextlib 3*da0073e9SAndroid Build Coastguard Worker 4*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.test_case 5*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo.testing 6*da0073e9SAndroid Build Coastguard Workerfrom torch._dynamo.testing import CompileCounter 7*da0073e9SAndroid Build Coastguard Workerfrom torch.backends.cuda import SDPAParams 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Worker@contextlib.contextmanager 11*da0073e9SAndroid Build Coastguard Workerdef allow_in_graph_sdpa_params(): 12*da0073e9SAndroid Build Coastguard Worker global SDPAParams 13*da0073e9SAndroid Build Coastguard Worker try: 14*da0073e9SAndroid Build Coastguard Worker old = SDPAParams 15*da0073e9SAndroid Build Coastguard Worker SDPAParams = torch._dynamo.allow_in_graph(SDPAParams) 16*da0073e9SAndroid Build Coastguard Worker yield 17*da0073e9SAndroid Build Coastguard Worker finally: 18*da0073e9SAndroid Build Coastguard Worker SDPAParams = old 19*da0073e9SAndroid Build Coastguard Worker 20*da0073e9SAndroid Build Coastguard Worker 21*da0073e9SAndroid Build Coastguard Workerclass TestSDPA(torch._dynamo.test_case.TestCase): 22*da0073e9SAndroid Build Coastguard Worker def assert_ref_equals_params(self, actual, expected): 23*da0073e9SAndroid Build Coastguard Worker self.assertIs(actual.query, expected.query) 24*da0073e9SAndroid Build Coastguard Worker self.assertIs(actual.key, expected.key) 25*da0073e9SAndroid Build Coastguard Worker self.assertIs(actual.value, expected.value) 26*da0073e9SAndroid Build Coastguard Worker self.assertIs(actual.attn_mask, expected.attn_mask) 27*da0073e9SAndroid Build Coastguard Worker 28*da0073e9SAndroid Build Coastguard Worker def test_returns_SDPAParams(self): 29*da0073e9SAndroid Build Coastguard Worker with allow_in_graph_sdpa_params(): 30*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Worker @torch.compile(fullgraph=True, backend=counter) 33*da0073e9SAndroid Build Coastguard Worker def fn(q, k, v, m): 34*da0073e9SAndroid Build Coastguard Worker return SDPAParams(q, k, v, m, 0.1, True, False) 35*da0073e9SAndroid Build Coastguard Worker 36*da0073e9SAndroid Build Coastguard Worker q = torch.randn(10) 37*da0073e9SAndroid Build Coastguard Worker k = torch.randn(10) 38*da0073e9SAndroid Build Coastguard Worker v = torch.randn(10) 39*da0073e9SAndroid Build Coastguard Worker m = torch.randn(10) 40*da0073e9SAndroid Build Coastguard Worker o = fn(q, k, v, m) 41*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(o, SDPAParams)) 42*da0073e9SAndroid Build Coastguard Worker self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True, False)) 43*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 1) 44*da0073e9SAndroid Build Coastguard Worker 45*da0073e9SAndroid Build Coastguard Worker def test_graph_break_SDPAParams(self): 46*da0073e9SAndroid Build Coastguard Worker with allow_in_graph_sdpa_params(): 47*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 48*da0073e9SAndroid Build Coastguard Worker 49*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend=counter) 50*da0073e9SAndroid Build Coastguard Worker def fn(q, k, v, m): 51*da0073e9SAndroid Build Coastguard Worker z = SDPAParams(q, k, v, m, 0.1, True, False) 52*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 53*da0073e9SAndroid Build Coastguard Worker return z, q + 1 54*da0073e9SAndroid Build Coastguard Worker 55*da0073e9SAndroid Build Coastguard Worker q = torch.randn(10) 56*da0073e9SAndroid Build Coastguard Worker k = torch.randn(10) 57*da0073e9SAndroid Build Coastguard Worker v = torch.randn(10) 58*da0073e9SAndroid Build Coastguard Worker m = torch.randn(10) 59*da0073e9SAndroid Build Coastguard Worker o, _ = fn(q, k, v, m) 60*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(o, SDPAParams)) 61*da0073e9SAndroid Build Coastguard Worker self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True, False)) 62*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 2) 63*da0073e9SAndroid Build Coastguard Worker 64*da0073e9SAndroid Build Coastguard Worker def test_input_SDPAParams(self): 65*da0073e9SAndroid Build Coastguard Worker with allow_in_graph_sdpa_params(): 66*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 67*da0073e9SAndroid Build Coastguard Worker 68*da0073e9SAndroid Build Coastguard Worker @torch.compile(backend=counter) 69*da0073e9SAndroid Build Coastguard Worker def fn(sdpap, q): 70*da0073e9SAndroid Build Coastguard Worker torch._dynamo.graph_break() 71*da0073e9SAndroid Build Coastguard Worker return sdpap, sdpap.query + q 72*da0073e9SAndroid Build Coastguard Worker 73*da0073e9SAndroid Build Coastguard Worker q = torch.randn(10) 74*da0073e9SAndroid Build Coastguard Worker k = torch.randn(10) 75*da0073e9SAndroid Build Coastguard Worker v = torch.randn(10) 76*da0073e9SAndroid Build Coastguard Worker m = torch.randn(10) 77*da0073e9SAndroid Build Coastguard Worker s = SDPAParams(q, k, v, m, 0.1, True, False) 78*da0073e9SAndroid Build Coastguard Worker o, _ = fn(s, q) 79*da0073e9SAndroid Build Coastguard Worker self.assertIs(o, s) 80*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 1) 81*da0073e9SAndroid Build Coastguard Worker 82*da0073e9SAndroid Build Coastguard Worker def test_intermediate_attr_access_SDPAParams(self): 83*da0073e9SAndroid Build Coastguard Worker with allow_in_graph_sdpa_params(): 84*da0073e9SAndroid Build Coastguard Worker counter = CompileCounter() 85*da0073e9SAndroid Build Coastguard Worker 86*da0073e9SAndroid Build Coastguard Worker @torch.compile(fullgraph=True, backend=counter) 87*da0073e9SAndroid Build Coastguard Worker def fn(q, k, v, m): 88*da0073e9SAndroid Build Coastguard Worker q += 1 89*da0073e9SAndroid Build Coastguard Worker z = SDPAParams(q, k, v, m, 0.1, True, False) 90*da0073e9SAndroid Build Coastguard Worker a = z.query 91*da0073e9SAndroid Build Coastguard Worker return a + 1, z, q 92*da0073e9SAndroid Build Coastguard Worker 93*da0073e9SAndroid Build Coastguard Worker q = torch.randn(10) 94*da0073e9SAndroid Build Coastguard Worker k = torch.randn(10) 95*da0073e9SAndroid Build Coastguard Worker v = torch.randn(10) 96*da0073e9SAndroid Build Coastguard Worker m = torch.randn(10) 97*da0073e9SAndroid Build Coastguard Worker _, o, _ = fn(q, k, v, m) 98*da0073e9SAndroid Build Coastguard Worker expected = SDPAParams(q, k, v, m, 0.1, True, False) 99*da0073e9SAndroid Build Coastguard Worker self.assert_ref_equals_params(o, expected) 100*da0073e9SAndroid Build Coastguard Worker self.assertEqual(counter.frame_count, 1) 101*da0073e9SAndroid Build Coastguard Worker 102*da0073e9SAndroid Build Coastguard Worker 103*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 104*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.test_case import run_tests 105*da0073e9SAndroid Build Coastguard Worker 106*da0073e9SAndroid Build Coastguard Worker run_tests() 107