# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized

test_data_t = tuple[torch.Tensor, int, int]

# Each entry is a 1-tuple so that ``parameterized.expand`` hands the whole
# (tensor, dim, index) triple to the test as a single ``test_data`` argument.
test_data_suite: list[tuple[test_data_t]] = [
    # (test_data, dim, index)
    ((torch.zeros(5, 3, 20), -1, 0),),
    ((torch.zeros(5, 3, 20), 0, -1),),
    ((torch.zeros(5, 3, 20), 0, 4),),
    ((torch.ones(10, 10, 10), 0, 2),),
    ((torch.rand(5, 3, 20, 2), 0, 2),),
    ((torch.rand(10, 10) - 0.5, 0, 0),),
    ((torch.randn(10) + 10, 0, 1),),
    ((torch.randn(10) - 10, 0, 2),),
    ((torch.arange(-16, 16, 0.2), 0, 1),),
]

# Operator names the exported graph is checked for.
_SELECT_COPY_TARGET = "torch.ops.aten.select_copy.int"
_SELECT_INT_TARGET = "torch.ops.aten.select.int"


def _permute_to_nhwc(test_data: test_data_t) -> bool:
    """Return the ``permute_memory_to_nhwc`` flag to use for *test_data*.

    For 4D tensors, do not permute to NHWC; every other rank permutes.
    """
    return len(test_data[0].shape) != 4


class TestSelect(unittest.TestCase):
    """Lowering tests for ``select`` / ``select_copy`` on the Arm backend."""

    class SelectCopy(torch.nn.Module):
        """Module wrapping ``torch.select_copy``."""

        def forward(self, x, dim: int, index: int):
            return torch.select_copy(x, dim=dim, index=index)

    class SelectInt(torch.nn.Module):
        """Module wrapping ``torch.select``."""

        def forward(self, x, dim: int, index: int):
            return torch.select(x, dim=dim, index=index)

    def _test_select_tosa_MI_pipeline(
        self,
        module: torch.nn.Module,
        test_data: test_data_t,
        export_target: str,
    ):
        """Run *module* through the non-quantized TOSA-0.80.0+MI pipeline.

        Args:
            module: Module under test.
            test_data: ``(tensor, dim, index)`` used both as example inputs
                and as the inputs for the final output comparison.
            export_target: Operator name passed to ``check`` after export.
        """
        compile_spec = common.get_tosa_compile_spec(
            "TOSA-0.80.0+MI",
            permute_memory_to_nhwc=_permute_to_nhwc(test_data),
        )
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .export()
            .check([export_target])
            # No quantization stage is run in this pipeline.
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_select_tosa_BI_pipeline(
        self,
        module: torch.nn.Module,
        test_data: test_data_t,
        export_target: str,
    ):
        """Run *module* through the quantized TOSA-0.80.0+BI pipeline.

        Args:
            module: Module under test.
            test_data: ``(tensor, dim, index)`` used both as example inputs
                and as the inputs for the final output comparison.
            export_target: Operator name passed to ``check`` after export.
        """
        compile_spec = common.get_tosa_compile_spec(
            "TOSA-0.80.0+BI",
            permute_memory_to_nhwc=_permute_to_nhwc(test_data),
        )
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check([export_target])
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .dump_artifact()
            .dump_operator_distribution()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_select_ethos_BI_pipeline(
        self,
        compile_spec: list[CompileSpec],
        module: torch.nn.Module,
        test_data: test_data_t,
        export_target: str,
    ):
        """Run *module* through the quantized pipeline with *compile_spec*.

        This pipeline stops after ``to_executorch()``; unlike the TOSA
        pipelines it does not call ``run_method_and_compare_outputs``.

        Args:
            compile_spec: Compile spec handed to ``ArmTester``.
            module: Module under test.
            test_data: ``(tensor, dim, index)`` used as example inputs.
            export_target: Operator name passed to ``check`` after export.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check([export_target])
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .dump_artifact()
            .dump_operator_distribution()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def _test_select_tosa_u55_BI_pipeline(
        self, module: torch.nn.Module, test_data: test_data_t, export_target: str
    ):
        """Run the Ethos BI pipeline with the compile spec from ``get_u55_compile_spec``."""
        self._test_select_ethos_BI_pipeline(
            common.get_u55_compile_spec(
                permute_memory_to_nhwc=_permute_to_nhwc(test_data)
            ),
            module,
            test_data,
            export_target,
        )

    def _test_select_tosa_u85_BI_pipeline(
        self, module: torch.nn.Module, test_data: test_data_t, export_target: str
    ):
        """Run the Ethos BI pipeline with the compile spec from ``get_u85_compile_spec``."""
        self._test_select_ethos_BI_pipeline(
            common.get_u85_compile_spec(
                permute_memory_to_nhwc=_permute_to_nhwc(test_data)
            ),
            module,
            test_data,
            export_target,
        )

    @parameterized.expand(test_data_suite)
    def test_select_copy_tosa_MI(self, test_data: test_data_t):
        self._test_select_tosa_MI_pipeline(
            self.SelectCopy(), test_data, export_target=_SELECT_COPY_TARGET
        )

    @parameterized.expand(test_data_suite)
    def test_select_int_tosa_MI(self, test_data: test_data_t):
        self._test_select_tosa_MI_pipeline(
            self.SelectInt(), test_data, export_target=_SELECT_INT_TARGET
        )

    @parameterized.expand(test_data_suite)
    def test_select_copy_tosa_BI(self, test_data: test_data_t):
        self._test_select_tosa_BI_pipeline(
            self.SelectCopy(), test_data, export_target=_SELECT_COPY_TARGET
        )

    @parameterized.expand(test_data_suite)
    def test_select_int_tosa_BI(self, test_data: test_data_t):
        self._test_select_tosa_BI_pipeline(
            self.SelectInt(), test_data, export_target=_SELECT_INT_TARGET
        )

    @parameterized.expand(test_data_suite)
    def test_select_copy_tosa_u55_BI(self, test_data: test_data_t):
        self._test_select_tosa_u55_BI_pipeline(
            self.SelectCopy(), test_data, export_target=_SELECT_COPY_TARGET
        )

    @parameterized.expand(test_data_suite)
    def test_select_int_tosa_u55_BI(self, test_data: test_data_t):
        self._test_select_tosa_u55_BI_pipeline(
            self.SelectInt(), test_data, export_target=_SELECT_INT_TARGET
        )

    @parameterized.expand(test_data_suite)
    def test_select_copy_tosa_u85_BI(self, test_data: test_data_t):
        self._test_select_tosa_u85_BI_pipeline(
            self.SelectCopy(), test_data, export_target=_SELECT_COPY_TARGET
        )

    @parameterized.expand(test_data_suite)
    def test_select_int_tosa_u85_BI(self, test_data: test_data_t):
        self._test_select_tosa_u85_BI_pipeline(
            self.SelectInt(), test_data, export_target=_SELECT_INT_TARGET
        )