# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from typing import Tuple

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized

# Fixed seed so the randomly generated test inputs below are reproducible.
torch.manual_seed(1)


class TestBMM(unittest.TestCase):
    """Tests Batch MatMul"""

    class BMM(torch.nn.Module):
        """Batched matrix multiply of two operands via ``torch.bmm``."""

        test_parameters = [
            (torch.rand(5, 3, 5), torch.rand(5, 5, 2)),
            (torch.rand(2, 1, 1), torch.rand(2, 1, 1)),
            (torch.ones(1, 55, 3), torch.ones(1, 3, 44)),
            (10000 * torch.randn(10, 1, 10), torch.randn(10, 10, 5)),
            (-10 * torch.randn(2, 32, 64), 5 + 5 * torch.randn(2, 64, 32)),
        ]

        def forward(self, x, y):
            return torch.bmm(x, y)

    class MatMul(torch.nn.Module):
        """``torch.matmul`` on 3-D operands.

        The shared pipelines expect this to appear as exactly one edge-dialect
        ``aten.bmm`` node after ``to_edge``.
        """

        test_parameters = [(torch.rand(2, 3, 5), torch.rand(2, 5, 2))]

        def forward(self, x, y):
            return torch.matmul(x, y)

    class BMMSingleInput(torch.nn.Module):
        """Multiplies a batch of square matrices by itself via ``torch.bmm(x, x)``."""

        test_parameters = [
            (torch.rand(20, 3, 3),),
            (torch.ones(2, 128, 128),),
            (10000 * torch.randn(4, 25, 25),),
            (5 + 5 * torch.randn(3, 64, 64),),
        ]

        def forward(self, x):
            return torch.bmm(x, x)

    def _test_bmm_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, ...]
    ):
        """Lower ``module`` with the TOSA MI (floating-point) profile and compare outputs.

        Asserts that no quantize/dequantize ops are present after export, that the
        edge graph contains exactly one ``aten.bmm``, and that the bmm is fully
        absorbed into a single delegate. The resulting program is then run and
        its outputs are compared against eager mode.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .check_count({"executorch_exir_dialects_edge__ops_aten_bmm_default": 1})
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_bmm_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_bmm_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor, ...]
    ):
        """Quantize ``module``, lower it with the TOSA BI (integer) profile and compare outputs.

        Same structural checks as the MI pipeline, except that quantize/dequantize
        ops are required to be present after export.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .check_count({"executorch_exir_dialects_edge__ops_aten_bmm_default": 1})
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_bmm_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_bmm_ethosu_BI_pipeline(
        self,
        module: torch.nn.Module,
        compile_spec: CompileSpec,
        test_data: Tuple[torch.Tensor, ...],
    ):
        """Quantize ``module`` and lower it for an Ethos-U target described by ``compile_spec``.

        Only checks that lowering to an ExecuTorch program succeeds and that the
        bmm ends up inside a single delegate. Outputs are not executed or
        compared here.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.bmm.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            # Keep in line with the TOSA pipelines: no bmm may be left undelegated.
            .check_not(["executorch_exir_dialects_edge__ops_aten_bmm_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    @parameterized.expand(BMM.test_parameters)
    def test_bmm_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor):
        test_data = (operand1, operand2)
        self._test_bmm_tosa_MI_pipeline(self.BMM(), test_data)

    @parameterized.expand(BMMSingleInput.test_parameters)
    def test_bmm_single_input_tosa_MI(self, operand1: torch.Tensor):
        test_data = (operand1,)
        self._test_bmm_tosa_MI_pipeline(self.BMMSingleInput(), test_data)

    @parameterized.expand(MatMul.test_parameters)
    def test_matmul_tosa_MI(self, operand1: torch.Tensor, operand2: torch.Tensor):
        test_data = (operand1, operand2)
        self._test_bmm_tosa_MI_pipeline(self.MatMul(), test_data)

    @parameterized.expand(MatMul.test_parameters)
    def test_matmul_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
        test_data = (operand1, operand2)
        self._test_bmm_tosa_BI_pipeline(self.MatMul(), test_data)

    @parameterized.expand(BMM.test_parameters)
    def test_bmm_tosa_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
        test_data = (operand1, operand2)
        self._test_bmm_tosa_BI_pipeline(self.BMM(), test_data)

    @parameterized.expand(BMMSingleInput.test_parameters)
    def test_bmm_single_input_tosa_BI(self, operand1: torch.Tensor):
        test_data = (operand1,)
        self._test_bmm_tosa_BI_pipeline(self.BMMSingleInput(), test_data)

    @parameterized.expand(BMM.test_parameters)
    def test_bmm_u55_BI(self, operand1: torch.Tensor, operand2: torch.Tensor):
        # Previously this ran the plain TOSA BI pipeline, duplicating
        # test_bmm_tosa_BI and never targeting the Ethos-U55.
        test_data = (operand1, operand2)
        self._test_bmm_ethosu_BI_pipeline(
            self.BMM(), common.get_u55_compile_spec(), test_data
        )

    # Expected to fail with error: Warning, unsupported fusing of TOSA Rescale previous operator is of type: Memcpy
    @parameterized.expand(BMMSingleInput.test_parameters)
    @unittest.expectedFailure
    def test_bmm_single_input_u55_BI(self, operand1: torch.Tensor):
        test_data = (operand1,)
        self._test_bmm_ethosu_BI_pipeline(
            self.BMMSingleInput(), common.get_u55_compile_spec(), test_data
        )

    @parameterized.expand(BMMSingleInput.test_parameters)
    def test_bmm_single_input_u85_BI(self, operand1: torch.Tensor):
        test_data = (operand1,)
        self._test_bmm_ethosu_BI_pipeline(
            self.BMMSingleInput(), common.get_u85_compile_spec(), test_data
        )