# xref: /aosp_15_r20/external/pytorch/test/dynamo/test_sdpa.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# Owner(s): ["module: dynamo"]
import contextlib

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo.testing import CompileCounter
from torch.backends.cuda import SDPAParams
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Worker@contextlib.contextmanager
11*da0073e9SAndroid Build Coastguard Workerdef allow_in_graph_sdpa_params():
12*da0073e9SAndroid Build Coastguard Worker    global SDPAParams
13*da0073e9SAndroid Build Coastguard Worker    try:
14*da0073e9SAndroid Build Coastguard Worker        old = SDPAParams
15*da0073e9SAndroid Build Coastguard Worker        SDPAParams = torch._dynamo.allow_in_graph(SDPAParams)
16*da0073e9SAndroid Build Coastguard Worker        yield
17*da0073e9SAndroid Build Coastguard Worker    finally:
18*da0073e9SAndroid Build Coastguard Worker        SDPAParams = old
19*da0073e9SAndroid Build Coastguard Worker
20*da0073e9SAndroid Build Coastguard Worker
class TestSDPA(torch._dynamo.test_case.TestCase):
    """Dynamo tests for ``SDPAParams`` objects used with ``torch.compile``.

    Each test runs inside ``allow_in_graph_sdpa_params()`` and uses a
    ``CompileCounter`` backend so the number of compiled frames can be
    asserted.
    """

    def assert_ref_equals_params(self, actual, expected):
        """Assert both params objects reference the very same tensor objects.

        Compares ``query``, ``key``, ``value`` and ``attn_mask`` by identity
        (``assertIs``), not by value.
        """
        self.assertIs(actual.query, expected.query)
        self.assertIs(actual.key, expected.key)
        self.assertIs(actual.value, expected.value)
        self.assertIs(actual.attn_mask, expected.attn_mask)

    def test_returns_SDPAParams(self):
        """A compiled function can construct and return an ``SDPAParams``."""
        with allow_in_graph_sdpa_params():
            counter = CompileCounter()

            @torch.compile(fullgraph=True, backend=counter)
            def fn(q, k, v, m):
                return SDPAParams(q, k, v, m, 0.1, True, False)

            q = torch.randn(10)
            k = torch.randn(10)
            v = torch.randn(10)
            m = torch.randn(10)
            o = fn(q, k, v, m)
            # assertIsInstance reports the actual type on failure, unlike
            # assertTrue(isinstance(...)).
            self.assertIsInstance(o, SDPAParams)
            self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True, False))
            self.assertEqual(counter.frame_count, 1)

    def test_graph_break_SDPAParams(self):
        """An ``SDPAParams`` built before an explicit graph break is still returned."""
        with allow_in_graph_sdpa_params():
            counter = CompileCounter()

            @torch.compile(backend=counter)
            def fn(q, k, v, m):
                z = SDPAParams(q, k, v, m, 0.1, True, False)
                torch._dynamo.graph_break()
                return z, q + 1

            q = torch.randn(10)
            k = torch.randn(10)
            v = torch.randn(10)
            m = torch.randn(10)
            o, _ = fn(q, k, v, m)
            self.assertIsInstance(o, SDPAParams)
            self.assert_ref_equals_params(o, SDPAParams(q, k, v, m, 0.1, True, False))
            # Two frames are expected: one on each side of the graph break.
            self.assertEqual(counter.frame_count, 2)

    def test_input_SDPAParams(self):
        """An ``SDPAParams`` passed in as an argument is returned as the same object."""
        with allow_in_graph_sdpa_params():
            counter = CompileCounter()

            @torch.compile(backend=counter)
            def fn(sdpap, q):
                torch._dynamo.graph_break()
                return sdpap, sdpap.query + q

            q = torch.randn(10)
            k = torch.randn(10)
            v = torch.randn(10)
            m = torch.randn(10)
            s = SDPAParams(q, k, v, m, 0.1, True, False)
            o, _ = fn(s, q)
            self.assertIs(o, s)
            # NOTE(review): a single frame is expected, presumably because
            # nothing precedes the graph break to be compiled -- confirm.
            self.assertEqual(counter.frame_count, 1)

    def test_intermediate_attr_access_SDPAParams(self):
        """Attributes of an ``SDPAParams`` created inside the graph can be read."""
        with allow_in_graph_sdpa_params():
            counter = CompileCounter()

            @torch.compile(fullgraph=True, backend=counter)
            def fn(q, k, v, m):
                # In-place update: ``q`` stays the same tensor object.
                q += 1
                z = SDPAParams(q, k, v, m, 0.1, True, False)
                a = z.query
                return a + 1, z, q

            q = torch.randn(10)
            k = torch.randn(10)
            v = torch.randn(10)
            m = torch.randn(10)
            _, o, _ = fn(q, k, v, m)
            # The comparison below is by identity only, so building the
            # expected params from the already-mutated ``q`` is fine.
            expected = SDPAParams(q, k, v, m, 0.1, True, False)
            self.assert_ref_equals_params(o, expected)
            self.assertEqual(counter.frame_count, 1)
101*da0073e9SAndroid Build Coastguard Worker
102*da0073e9SAndroid Build Coastguard Worker
# Script entry point: hand control to the Dynamo test runner.
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
107