xref: /aosp_15_r20/external/executorch/backends/arm/test/ops/test_bmm.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9from typing import Tuple
10
11import torch
12from executorch.backends.arm.test import common
13from executorch.backends.arm.test.tester.arm_tester import ArmTester
14from executorch.exir.backend.compile_spec_schema import CompileSpec
15from parameterized import parameterized
16
# Seed the global RNG before the test classes are defined: their
# ``test_parameters`` tensors are drawn with torch.rand/torch.randn while the
# class bodies execute at import time, so seeding here makes them reproducible.
torch.random.manual_seed(1)
18
19
class TestBMM(unittest.TestCase):
    """Tests batched matrix multiplication lowered through the Arm backend.

    Each module below is driven through ``ArmTester`` with either a TOSA
    compile spec (MI: no quantization; BI: the module is quantized first) or
    an Ethos-U55/U85 compile spec.
    """

    class BMM(torch.nn.Module):
        """Computes ``torch.bmm(x, y)`` for two distinct inputs."""

        # (x, y) pairs with shapes (B, N, M) and (B, M, P). The scaled/shifted
        # randn cases exercise wide and offset value ranges.
        test_parameters = [
            (torch.rand(5, 3, 5), torch.rand(5, 5, 2)),
            (torch.rand(2, 1, 1), torch.rand(2, 1, 1)),
            (torch.ones(1, 55, 3), torch.ones(1, 3, 44)),
            (10000 * torch.randn(10, 1, 10), torch.randn(10, 10, 5)),
            (-10 * torch.randn(2, 32, 64), 5 + 5 * torch.randn(2, 64, 32)),
        ]

        def forward(self, x, y):
            return torch.bmm(x, y)

    class MatMul(torch.nn.Module):
        """Computes ``torch.matmul(x, y)`` on 3-D inputs."""

        test_parameters = [(torch.rand(2, 3, 5), torch.rand(2, 5, 2))]

        def forward(self, x, y):
            return torch.matmul(x, y)

    class BMMSingleInput(torch.nn.Module):
        """Computes ``torch.bmm(x, x)``: a single tensor as both operands."""

        # Square (B, N, N) inputs so that bmm(x, x) is well-formed.
        test_parameters = [
            (torch.rand(20, 3, 3),),
            (torch.ones(2, 128, 128),),
            (10000 * torch.randn(4, 25, 25),),
            (5 + 5 * torch.randn(3, 64, 64),),
        ]

        def forward(self, x):
            return torch.bmm(x, x)

    def _test_bmm_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, ...]
    ):
        """Lower ``module`` with the TOSA MI spec and compare outputs.

        Asserts that:
        - no ``quantized_decomposed`` ops appear after export;
        - the edge graph holds exactly one ``aten.bmm``;
        - partitioning replaces it with a single delegate call.

        It then runs the ExecuTorch program and compares its outputs via
        ``run_method_and_compare_outputs``. This helper is also used for
        ``MatMul``, so the matmul is expected to reach edge as ``aten.bmm``.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .check_count({"executorch_exir_dialects_edge__ops_aten_bmm_default": 1})
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_bmm_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_bmm_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, ...]
    ):
        """Quantize ``module``, lower it with the TOSA BI spec and compare outputs.

        Same checks as the MI pipeline, except that ``quantized_decomposed``
        ops must be present after export, since the module is quantized first.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .check_count({"executorch_exir_dialects_edge__ops_aten_bmm_default": 1})
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_bmm_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_bmm_ethosu_BI_pipeline(
        self,
        module: torch.nn.Module,
        compile_spec: CompileSpec,
        test_data: Tuple[torch.Tensor, ...],
    ):
        """Quantize ``module`` and lower it for an Ethos-U target.

        ``compile_spec`` is passed straight to ``ArmTester``. The callers here
        pass the result of ``common.get_u55_compile_spec()`` or
        ``common.get_u85_compile_spec()``.

        Checks that exactly one ``aten.bmm`` appears after export and that
        partitioning produces a single delegate call. The program is built
        with ``to_executorch`` but not executed, so no output comparison is
        made here.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.bmm.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    @parameterized.expand(BMM.test_parameters)
    def test_bmm_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor):
        test_data = (operand1, operand2)
        self._test_bmm_tosa_MI_pipeline(self.BMM(), test_data)

    @parameterized.expand(BMMSingleInput.test_parameters)
    def test_bmm_single_input_tosa_MI(self, operand1: torch.Tensor):
        test_data = (operand1,)
        self._test_bmm_tosa_MI_pipeline(self.BMMSingleInput(), test_data)

    @parameterized.expand(MatMul.test_parameters)
    def test_matmul_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor):
        test_data = (operand1, operand2)
        self._test_bmm_tosa_MI_pipeline(self.MatMul(), test_data)

    @parameterized.expand(MatMul.test_parameters)
    def test_matmul_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
        test_data = (operand1, operand2)
        self._test_bmm_tosa_BI_pipeline(self.MatMul(), test_data)

    @parameterized.expand(BMM.test_parameters)
    def test_bmm_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
        test_data = (operand1, operand2)
        self._test_bmm_tosa_BI_pipeline(self.BMM(), test_data)

    @parameterized.expand(BMMSingleInput.test_parameters)
    def test_bmm_single_input_tosa_BI(self, operand1: torch.Tensor):
        test_data = (operand1,)
        self._test_bmm_tosa_BI_pipeline(self.BMMSingleInput(), test_data)

    @parameterized.expand(BMM.test_parameters)
    def test_bmm_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
        # Use the Ethos-U pipeline with the U55 spec. The TOSA BI path is
        # already covered by test_bmm_tosa_BI, so routing this test there only
        # duplicated it and never exercised the U55 target.
        # NOTE(review): if Vela rejects two-input bmm on U55 (as it does for
        # the single-input case below), mark this @unittest.expectedFailure
        # together with the observed error text.
        test_data = (operand1, operand2)
        self._test_bmm_ethosu_BI_pipeline(
            self.BMM(), common.get_u55_compile_spec(), test_data
        )

    # Expected to fail with error: Warning, unsupported fusing of TOSA Rescale previous operator is of type: Memcpy
    @parameterized.expand(BMMSingleInput.test_parameters)
    @unittest.expectedFailure
    def test_bmm_single_input_u55_BI(self, operand1: torch.Tensor):
        test_data = (operand1,)
        self._test_bmm_ethosu_BI_pipeline(
            self.BMMSingleInput(), common.get_u55_compile_spec(), test_data
        )

    @parameterized.expand(BMMSingleInput.test_parameters)
    def test_bmm_single_input_u85_BI(self, operand1: torch.Tensor):
        test_data = (operand1,)
        self._test_bmm_ethosu_BI_pipeline(
            self.BMMSingleInput(), common.get_u85_compile_spec(), test_data
        )
165