xref: /aosp_15_r20/external/pytorch/test/jit/test_batch_mm.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import torch
4from torch.testing import FileCheck
5from torch.testing._internal.jit_utils import JitTestCase
6
7
# Guard against direct execution; per the message below, these tests are meant
# to be launched through test/test_jit.py instead.
if __name__ == "__main__":
    _direct_run_message = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_direct_run_message)
14
15
16class TestBatchMM(JitTestCase):
17    @staticmethod
18    def _get_test_tensors(n: int):
19        return [
20            torch.tensor([[1 + x, 2 + x, 3 + x], [4 + x, 5 + x, 6 + x]])
21            if x % 2 == 0
22            else torch.tensor([[1 + x, 2 + x], [3 + x, 4 + x], [5 + x, 6 + x]])
23            for x in range(n)
24        ]
25
26    def test_batch_mm_no_mutation(self):
27        def test_batch_mm(
28            T1: torch.Tensor,
29            T2: torch.Tensor,
30            T3: torch.Tensor,
31            T4: torch.Tensor,
32            T5: torch.Tensor,
33            T6: torch.Tensor,
34            T7: torch.Tensor,
35            T8: torch.Tensor,
36        ):
37            return (
38                torch.mm(T1, T2)
39                + torch.mm(T3, T4)
40                + torch.mm(T5, T6)
41                + torch.mm(T7, T8)
42            )
43
44        test_batch_mm_scripted = torch.jit.script(test_batch_mm)
45
46        tensors = TestBatchMM._get_test_tensors(8)
47        expected = test_batch_mm(*tensors)
48
49        FileCheck().check_count("aten::mm", 4, exactly=True).run(
50            test_batch_mm_scripted.graph
51        )
52        self.run_pass("batch_mm", test_batch_mm_scripted.graph)
53        FileCheck().check_count("prim::MMTreeReduce", 1, exactly=True).run(
54            test_batch_mm_scripted.graph
55        )
56
57        actual = test_batch_mm_scripted(*tensors)
58        self.assertEqual(expected, actual, atol=1e-9, rtol=1e-9)
59
60    def test_batch_mm_permitted_mutation(self):
61        def test_batch_mm(
62            T1: torch.Tensor,
63            T2: torch.Tensor,
64            T3: torch.Tensor,
65            T4: torch.Tensor,
66            T5: torch.Tensor,
67            T6: torch.Tensor,
68            T7: torch.Tensor,
69            T8: torch.Tensor,
70        ):
71            result = {}
72            result["product"] = (
73                torch.mm(T1, T2)
74                + torch.mm(T3, T4)
75                + torch.mm(T5, T6)
76                + torch.mm(T7, T8)
77            )
78            result["constant"] = torch.tensor([42.0])
79            return result
80
81        test_batch_mm_scripted = torch.jit.script(test_batch_mm)
82
83        tensors = TestBatchMM._get_test_tensors(8)
84        expected = test_batch_mm(*tensors)
85
86        FileCheck().check_count("aten::mm", 4, exactly=True).run(
87            test_batch_mm_scripted.graph
88        )
89        self.run_pass("batch_mm", test_batch_mm_scripted.graph)
90        FileCheck().check_count("prim::MMTreeReduce", 1, exactly=True).run(
91            test_batch_mm_scripted.graph
92        )
93
94        actual = test_batch_mm_scripted(*tensors)
95        self.assertEqual(expected, actual, atol=1e-9, rtol=1e-9)
96
97    def test_batch_mm_prohibited_mutation(self):
98        @torch.jit.script
99        def test_batch_mm(n: int):
100            T1 = torch.zeros((n, n))
101            T2 = torch.zeros((n, n))
102            T3 = torch.zeros((n, n))
103            T4 = torch.zeros((n, n))
104            T5 = torch.zeros((n, n))
105            T6 = torch.zeros((n, n))
106            T7 = torch.zeros((n, n))
107            T8 = torch.zeros((n, n))
108            torch.relu_(T1)
109            result = (
110                torch.mm(T1, T2)
111                + torch.mm(T3, T4)
112                + torch.mm(T5, T6)
113                + torch.mm(T7, T8)
114            )
115            return result
116
117        FileCheck().check_count("aten::mm", 4, exactly=True).run(test_batch_mm.graph)
118        self.run_pass("batch_mm", test_batch_mm.graph)
119        FileCheck().check_count("aten::mm", 4, exactly=True).check_not(
120            "prim::MMTreeReduce"
121        ).run(test_batch_mm.graph)
122
123    def test_batch_mm_prohibited_mutation_multiple_adds(self):
124        @torch.jit.script
125        def test_batch_mm(n: int):
126            T1 = torch.zeros((n, n))
127            T2 = torch.zeros((n, n))
128            T3 = torch.zeros((n, n))
129            T4 = torch.zeros((n, n))
130            T5 = torch.zeros((n, n))
131            T6 = torch.zeros((n, n))
132            T7 = torch.zeros((n, n))
133            T8 = torch.zeros((n, n))
134            T9 = torch.zeros((n, n))
135            T10 = torch.zeros((n, n))
136            torch.relu_(T1)
137            result = {}
138            result["no_mutated_parameters"] = (
139                torch.mm(T2, T3)
140                + torch.mm(T4, T5)
141                + torch.mm(T6, T7)
142                + torch.mm(T8, T9)
143            )
144            result["all_parameters"] = (
145                torch.mm(T1, T2)
146                + torch.mm(T3, T4)
147                + torch.mm(T5, T6)
148                + torch.mm(T7, T8)
149                + torch.mm(T9, T10)
150            )
151            return result
152
153        self.run_pass("batch_mm", test_batch_mm.graph)
154        FileCheck().check_count("prim::MMTreeReduce", 1, exactly=True).check_count(
155            "aten::mm", 5, exactly=True
156        ).run(test_batch_mm.graph)
157
158    def test_batch_mm_prohibited_mutation_if_node(self):
159        @torch.jit.script
160        def test_batch_mm(n: int, use_t1: bool):
161            T1 = torch.zeros((n, n))
162            T2 = torch.zeros((n, n))
163            T3 = torch.zeros((n, n))
164            T4 = torch.zeros((n, n))
165            T5 = torch.zeros((n, n))
166            T6 = torch.zeros((n, n))
167            T7 = torch.zeros((n, n))
168            T8 = torch.zeros((n, n))
169            T9 = torch.zeros((n, n))
170            T10 = torch.zeros((n, n))
171            if use_t1:
172                torch.relu_(T1)
173                return (
174                    torch.mm(T1, T2)
175                    + torch.mm(T3, T4)
176                    + torch.mm(T5, T6)
177                    + torch.mm(T7, T8)
178                    + torch.mm(T9, T10)
179                )
180            else:
181                return (
182                    torch.mm(T2, T3)
183                    + torch.mm(T4, T5)
184                    + torch.mm(T6, T7)
185                    + torch.mm(T8, T9)
186                )
187
188        self.run_pass("batch_mm", test_batch_mm.graph)
189        FileCheck().check_count("aten::mm", 5, exactly=True).check_count(
190            "prim::MMTreeReduce", 1, exactly=True
191        ).run(test_batch_mm.graph)
192
193    def test_batch_mm_side_permitted_mutation(self):
194        @torch.jit.script
195        def test_batch_mm(n: int):
196            result = {}
197            A = torch.zeros((n, n))
198            T1 = torch.zeros((n, n))
199            T2 = torch.zeros((n, n))
200            T3 = torch.zeros((n, n))
201            T4 = torch.zeros((n, n))
202            T5 = torch.zeros((n, n))
203            T6 = torch.zeros((n, n))
204            T7 = torch.zeros((n, n))
205            T8 = torch.zeros((n, n))
206            result["T1"] = torch.mm(A, T1)
207            result["T2"] = torch.mm(A, T2)
208            result["T3"] = torch.mm(A, T3)
209            result["T4"] = torch.mm(A, T4)
210            result["T5"] = torch.mm(A, T5)
211            result["T6"] = torch.mm(A, T6)
212            result["T7"] = torch.mm(A, T7)
213            result["T8"] = torch.mm(A, T8)
214            return result
215
216        FileCheck().check_count("aten::mm", 8, exactly=True).run(test_batch_mm.graph)
217        self.run_pass("batch_mm", test_batch_mm.graph)
218        FileCheck().check_count("prim::MMBatchSide", 1, exactly=True).check_not(
219            "aten::mm"
220        ).run(test_batch_mm.graph)
221
222    def test_batch_mm_side_prohibited_mutation_uncommon_side(self):
223        @torch.jit.script
224        def test_batch_mm(n: int):
225            A = torch.zeros((n, n))
226            T1 = torch.zeros((n, n))
227            T2 = torch.zeros((n, n))
228            T3 = torch.zeros((n, n))
229            T4 = torch.zeros((n, n))
230            T5 = torch.zeros((n, n))
231            T6 = torch.zeros((n, n))
232            T7 = torch.zeros((n, n))
233            T8 = torch.zeros((n, n))
234            T9 = torch.zeros((n, n))
235            T10 = torch.zeros((n, n))
236            torch.relu_(T1)
237            result = {}
238            result["T1"] = torch.mm(A, T1)
239            result["T2"] = torch.mm(A, T2)
240            result["T3"] = torch.mm(A, T3)
241            result["T4"] = torch.mm(A, T4)
242            result["T5"] = torch.mm(A, T5)
243            result["T6"] = torch.mm(A, T6)
244            result["T7"] = torch.mm(A, T7)
245            result["T8"] = torch.mm(A, T8)
246            result["T9"] = torch.mm(A, T9)
247            result["T10"] = torch.mm(A, T10)
248            return result
249
250        FileCheck().check_count("aten::mm", 10, exactly=True).run(test_batch_mm.graph)
251        self.run_pass("batch_mm", test_batch_mm.graph)
252
253        FileCheck().check_count("aten::mm", 1, exactly=True).run(test_batch_mm.graph)
254        FileCheck().check_count("prim::MMBatchSide", 1, exactly=True).run(
255            test_batch_mm.graph
256        )
257
258    def test_batch_mm_side_prohibited_mutation_common_side(self):
259        @torch.jit.script
260        def test_batch_mm(n: int):
261            A = torch.zeros((n, n))
262            T1 = torch.zeros((n, n))
263            T2 = torch.zeros((n, n))
264            T3 = torch.zeros((n, n))
265            T4 = torch.zeros((n, n))
266            T5 = torch.zeros((n, n))
267            T6 = torch.zeros((n, n))
268            T7 = torch.zeros((n, n))
269            T8 = torch.zeros((n, n))
270            T9 = torch.zeros((n, n))
271            T10 = torch.zeros((n, n))
272            torch.relu_(A)
273            result = {}
274            result["T1"] = torch.mm(A, T1)
275            result["T2"] = torch.mm(A, T2)
276            result["T3"] = torch.mm(A, T3)
277            result["T4"] = torch.mm(A, T4)
278            result["T5"] = torch.mm(A, T5)
279            result["T6"] = torch.mm(A, T6)
280            result["T7"] = torch.mm(A, T7)
281            result["T8"] = torch.mm(A, T8)
282            result["T9"] = torch.mm(A, T9)
283            result["T10"] = torch.mm(A, T10)
284            return result
285
286        FileCheck().check_count("aten::mm", 10, exactly=True).run(test_batch_mm.graph)
287        self.run_pass("batch_mm", test_batch_mm.graph)
288        FileCheck().check_count("aten::mm", 10, exactly=True).check_not(
289            "prim::MMBatchSide"
290        ).run(test_batch_mm.graph)
291