xref: /aosp_15_r20/external/executorch/backends/arm/test/ops/test_mm.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import logging
8import unittest
9
10from typing import Tuple
11
12import torch
13from executorch.backends.arm.test import common
14from executorch.backends.arm.test.tester.arm_tester import ArmTester
15from executorch.exir.backend.backend_details import CompileSpec
16from parameterized import parameterized
17
# Module-level logger for this test file, configured to emit INFO and above.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Seed torch's RNG before the test classes below are defined: their
# torch.rand/torch.randn test tensors are created when the class bodies are
# executed, so seeding here keeps those inputs reproducible between runs.
torch.manual_seed(0)
22
23
class TestMM(unittest.TestCase):
    """Tests lowering of ``torch.mm`` (MatMul) through the Arm test pipelines.

    Two small modules are exercised: ``MM`` multiplies two distinct inputs and
    ``MMSingleInput`` multiplies one input with itself. Each module is run
    through three pipelines built on ``ArmTester``:

    * TOSA MI  - no quantization step, output comparison at the end.
    * TOSA BI  - quantization step first, output comparison at the end.
    * Ethos-U BI (U55 / U85 compile specs) - quantization step first; the
      pipeline stops after ``to_executorch()`` and does not compare outputs.
    """

    class MM(torch.nn.Module):
        """Module that multiplies two separate input matrices with ``torch.mm``."""

        # Each entry is an (x, y) operand pair whose inner dimensions match.
        # The tensors are created once, when this class body is executed.
        test_parameters = [
            (torch.rand(3, 5), torch.rand(5, 2)),
            (torch.rand(1, 1), torch.rand(1, 1)),
            (torch.ones(55, 3), torch.ones(3, 44)),
            (10000 * torch.randn(1, 10), torch.randn(10, 5)),
            (-10 * torch.randn(32, 64), 5 + 5 * torch.randn(64, 32)),
        ]

        def forward(self, x, y):
            """Return the matrix product ``torch.mm(x, y)``."""
            return torch.mm(x, y)

    class MMSingleInput(torch.nn.Module):
        """Module that multiplies a single square input with itself."""

        # Each entry is a 1-tuple holding a square matrix, so that
        # torch.mm(x, x) has compatible dimensions.
        test_parameters = [
            (torch.rand(3, 3),),
            (torch.ones(128, 128),),
            (10000 * torch.randn(25, 25),),
            (5 + 5 * torch.randn(64, 64),),
        ]

        def forward(self, x):
            """Return the matrix product ``torch.mm(x, x)``."""
            return torch.mm(x, x)

    def _test_mm_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, ...]
    ) -> None:
        """Run ``module`` through the non-quantized TOSA MI pipeline.

        Args:
            module: Module under test (``MM`` or ``MMSingleInput``).
            test_data: Example inputs for ``module``; one tensor per forward
                argument, so the tuple length varies between modules.

        The stage checks raise on mismatch, which fails the calling test.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            # Exactly one mm op is expected in the exported graph, and no
            # quantization ops since this pipeline never calls quantize().
            .check_count({"torch.ops.aten.mm.default": 1})
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            # After partitioning the edge mm op must no longer be visible and
            # a single delegate call must be present.
            .check_not(["executorch_exir_dialects_edge__ops_aten_mm_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_mm_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, ...]
    ) -> None:
        """Run ``module`` through the quantized TOSA BI pipeline.

        Args:
            module: Module under test (``MM`` or ``MMSingleInput``).
            test_data: Example inputs for ``module``; one tensor per forward
                argument, so the tuple length varies between modules.

        Same stages as the MI pipeline, except that the module is quantized
        before export and quantization ops are therefore required to appear.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.mm.default": 1})
            # Unlike the MI pipeline, quantization ops must be present here.
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_mm_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_mm_ethosu_BI_pipeline(
        self,
        compile_spec: CompileSpec,
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor, ...],
    ) -> None:
        """Run ``module`` through the quantized pipeline for an Ethos-U target.

        Args:
            compile_spec: Target compile spec supplied by the caller (the
                tests pass the U55 or U85 spec from ``common``).
            module: Module under test (``MM`` or ``MMSingleInput``).
            test_data: Example inputs for ``module``; one tensor per forward
                argument, so the tuple length varies between modules.

        This pipeline only checks that lowering succeeds: it ends at
        ``to_executorch()`` and does not run the program or compare outputs.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.mm.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    # NOTE: the test methods below are documented with comments rather than
    # docstrings so that unittest keeps reporting them by method name.

    # Two-operand mm through the TOSA MI pipeline.
    @parameterized.expand(MM.test_parameters)
    def test_mm_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor):
        test_data = (operand1, operand2)
        self._test_mm_tosa_MI_pipeline(self.MM(), test_data)

    # Single-operand mm (x @ x) through the TOSA MI pipeline.
    @parameterized.expand(MMSingleInput.test_parameters)
    def test_mm_single_input_tosa_MI(self, operand1: torch.Tensor):
        test_data = (operand1,)
        self._test_mm_tosa_MI_pipeline(self.MMSingleInput(), test_data)

    # Two-operand mm through the TOSA BI pipeline.
    @parameterized.expand(MM.test_parameters)
    def test_mm_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
        test_data = (operand1, operand2)
        self._test_mm_tosa_BI_pipeline(self.MM(), test_data)

    # Single-operand mm (x @ x) through the TOSA BI pipeline.
    @parameterized.expand(MMSingleInput.test_parameters)
    def test_mm_single_input_tosa_BI(self, operand1: torch.Tensor):
        test_data = (operand1,)
        self._test_mm_tosa_BI_pipeline(self.MMSingleInput(), test_data)

    # Expected to fail with error: CPU performance estimation for "MatMul" not implemented
    @parameterized.expand(MM.test_parameters)
    @unittest.expectedFailure
    def test_mm_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
        test_data = (operand1, operand2)
        self._test_mm_ethosu_BI_pipeline(
            common.get_u55_compile_spec(), self.MM(), test_data
        )

    # Expected to fail with error: Warning, unsupported fusing of TOSA Rescale previous operator is of type: Memcpy
    @parameterized.expand(MMSingleInput.test_parameters)
    @unittest.expectedFailure
    def test_mm_single_input_u55_BI(self, operand1: torch.Tensor):
        test_data = (operand1,)
        self._test_mm_ethosu_BI_pipeline(
            common.get_u55_compile_spec(), self.MMSingleInput(), test_data
        )

    # Two-operand mm lowered with the U85 compile spec (lowering only).
    @parameterized.expand(MM.test_parameters)
    def test_mm_u85_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
        test_data = (operand1, operand2)
        self._test_mm_ethosu_BI_pipeline(
            common.get_u85_compile_spec(), self.MM(), test_data
        )

    # Single-operand mm (x @ x) lowered with the U85 compile spec (lowering only).
    @parameterized.expand(MMSingleInput.test_parameters)
    def test_mm_single_input_u85_BI(self, operand1: torch.Tensor):
        test_data = (operand1,)
        self._test_mm_ethosu_BI_pipeline(
            common.get_u85_compile_spec(), self.MMSingleInput(), test_data
        )
164