# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.exir.backend.backend_details import CompileSpec
from parameterized import parameterized

# Each entry is (test_name, input, other); see torch.mul() for the operand
# semantics. Several cases pair operands of different shapes and therefore
# exercise broadcasting (e.g. (4, 5) * (1, 5), (5, 10, 25, 20) * (5, 1, 1, 20)).
# NOTE(review): the tensors are drawn from the global torch RNG at import time
# and no seed is set in this file, so values differ between runs unless a seed
# is fixed elsewhere -- confirm this is intended, especially for the quantized
# (BI) comparisons.
test_data_suite = [
    (
        "op_mul_rank1_ones",
        torch.ones(5),
        torch.ones(5),
    ),
    (
        "op_mul_rank2_rand",
        torch.rand(4, 5),
        torch.rand(1, 5),
    ),
    (
        "op_mul_rank3_randn",
        torch.randn(10, 5, 2),
        torch.randn(10, 5, 2),
    ),
    (
        "op_mul_rank4_randn",
        torch.randn(5, 10, 25, 20),
        torch.randn(5, 10, 25, 20),
    ),
    (
        "op_mul_rank4_ones_mul_negative",
        torch.ones(1, 10, 25, 20),
        (-1) * torch.ones(5, 10, 25, 20),
    ),
    (
        "op_mul_rank4_negative_large_rand",
        (-200) * torch.rand(5, 10, 25, 20),
        torch.rand(5, 1, 1, 20),
    ),
    (
        "op_mul_rank4_large_randn",
        200 * torch.randn(5, 10, 25, 20),
        torch.rand(5, 10, 25, 1),
    ),
]

# Backward-compatible alias for the original, misspelled name so that any
# other module still referring to it keeps working.
test_data_sute = test_data_suite


class TestMul(unittest.TestCase):
    """Lower an elementwise ``input * other`` module through the Arm backend.

    Each case in ``test_data_suite`` is run through the TOSA MI, TOSA BI,
    Ethos-U55 BI and Ethos-U85 BI flows.
    """

    class Mul(torch.nn.Module):
        """Minimal module whose forward is a single tensor multiplication."""

        def forward(
            self,
            input_: torch.Tensor,
            other_: torch.Tensor,
        ):
            return input_ * other_

    def _test_mul_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: tuple[torch.Tensor, torch.Tensor]
    ):
        """Run the floating-point (MI) TOSA flow and compare outputs.

        The module is exported without quantization. The flow checks that the
        exported graph has exactly one ``aten.mul.Tensor`` and no
        ``quantized_decomposed`` ops, and that partitioning produces exactly one
        delegate call. The lowered program is then executed and its outputs are
        compared with the eager module.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(
                    "TOSA-0.80.0+MI", permute_memory_to_nhwc=True
                ),
            )
            .export()
            .check_count({"torch.ops.aten.mul.Tensor": 1})
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_mul_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: tuple[torch.Tensor, torch.Tensor]
    ):
        """Run the quantized (BI) TOSA flow and compare outputs.

        The module is quantized before export, and ``quantized_decomposed`` ops
        must be present. The graph must still contain exactly one
        ``aten.mul.Tensor`` and lower to exactly one delegate call. Outputs are
        compared with a quantization tolerance of ``qtol=1.0``.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec(
                    "TOSA-0.80.0+BI", permute_memory_to_nhwc=True
                ),
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.mul.Tensor": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data, qtol=1.0)
        )

    def _test_mul_ethosu_BI_pipeline(
        self,
        compile_spec: CompileSpec,
        module: torch.nn.Module,
        test_data: tuple[torch.Tensor, torch.Tensor],
    ):
        """Run the quantized flow for an Ethos-U target given by ``compile_spec``.

        The checks are the same as in the TOSA BI pipeline, but the flow stops
        after ``to_executorch()`` and never runs the program or compares
        outputs. NOTE(review): presumably this is because execution needs the
        target hardware or a simulator, which is not available here. Confirm.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.mul.Tensor": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    # `test_name` is consumed only by `parameterized.expand` to name each
    # generated test case; the test bodies do not use it.

    @parameterized.expand(test_data_suite)
    def test_mul_tosa_MI(
        self,
        test_name: str,
        input_: torch.Tensor,
        other_: torch.Tensor,
    ):
        test_data = (input_, other_)
        self._test_mul_tosa_MI_pipeline(self.Mul(), test_data)

    @parameterized.expand(test_data_suite)
    def test_mul_tosa_BI(
        self,
        test_name: str,
        input_: torch.Tensor,
        other_: torch.Tensor,
    ):
        test_data = (input_, other_)
        self._test_mul_tosa_BI_pipeline(self.Mul(), test_data)

    @parameterized.expand(test_data_suite)
    def test_mul_u55_BI(
        self,
        test_name: str,
        input_: torch.Tensor,
        other_: torch.Tensor,
    ):
        test_data = (input_, other_)
        self._test_mul_ethosu_BI_pipeline(
            common.get_u55_compile_spec(), self.Mul(), test_data
        )

    @parameterized.expand(test_data_suite)
    def test_mul_u85_BI(
        self,
        test_name: str,
        input_: torch.Tensor,
        other_: torch.Tensor,
    ):
        test_data = (input_, other_)
        self._test_mul_ethosu_BI_pipeline(
            common.get_u85_compile_spec(), self.Mul(), test_data
        )