# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized

# One test case: (input tensor, chunk size or explicit section sizes, split dim).
test_data_t = tuple[torch.Tensor, int | list[int], int]


class TestSimpleSplit(unittest.TestCase):
    """Lower several ``Tensor.split`` variants through the Arm backend.

    Each test exports a small module, partitions it and checks that exactly
    one delegate call is produced. The TOSA flows additionally execute the
    program and compare its outputs against eager PyTorch.
    """

    class Split(torch.nn.Module):
        # Shared inputs for every test in this file. Entries 3 and 5 pass an
        # explicit list of section sizes; all other entries pass one integer
        # chunk size.
        test_data: list[tuple[test_data_t]] = [
            ((torch.rand(10), 2, 0),),
            ((torch.rand(10, 10), 3, 1),),
            ((torch.rand(10, 10), 4, -1),),
            ((torch.rand(10, 15, 10), [2, 2, 11], 1),),
            ((torch.rand(4, 4, 4, 4), 2, 0),),
            ((torch.rand(4, 4, 4, 4), [1, 1, 1, 1], -2),),
        ]

        def forward(
            self, x: torch.Tensor, split_size_or_sections: int | list[int], dim: int
        ):
            # Return every chunk produced by the split.
            return x.split(split_size_or_sections, dim)

    class SplitWithSizes(torch.nn.Module):
        def forward(self, x: torch.Tensor, split_sizes: list[int], dim: int):
            # Call the explicit-sizes variant directly.
            return x.split_with_sizes(split_sizes, dim)

    class SplitSingleOut(torch.nn.Module):
        def forward(
            self, x: torch.Tensor, split_size_or_sections: int | list[int], dim: int
        ):
            # Keep only the second chunk.
            chunks = x.split(split_size_or_sections, dim)
            return chunks[1]

    class SplitTwoOut(torch.nn.Module):
        def forward(
            self, x: torch.Tensor, split_size_or_sections: int | list[int], dim: int
        ):
            # Keep the second and third chunks. Every test input yields at
            # least three chunks, except [2, 2, 11], which yields exactly three.
            pieces = x.split(split_size_or_sections, dim)
            return pieces[1:3]

    def _test_split_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: test_data_t
    ):
        """Lower ``module`` with the TOSA MI spec and compare outputs.

        After ``to_edge`` the graph must contain the edge-dialect
        ``split_with_sizes_copy`` op. Partitioning must then produce exactly
        one delegate call before the program is run and compared.
        """
        tester = ArmTester(
            module,
            example_inputs=test_data,
            compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
        )
        expected_edge_op = (
            "executorch_exir_dialects_edge__ops_aten_split_with_sizes_copy_default"
        )
        one_delegate = {"torch.ops.higher_order.executorch_call_delegate": 1}
        (
            tester.export()
            .to_edge()
            .check([expected_edge_op])
            .partition()
            .check_count(one_delegate)
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_split_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: test_data_t
    ):
        """Quantize ``module``, lower it with the TOSA BI spec and compare outputs.

        The comparison passes ``qtol=1`` to the tester. Presumably this is a
        quantization tolerance; confirm against ArmTester.
        """
        tester = ArmTester(
            module,
            example_inputs=test_data,
            compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
        )
        one_delegate = {"torch.ops.higher_order.executorch_call_delegate": 1}
        (
            tester.quantize()
            .export()
            .to_edge()
            .partition()
            .check_count(one_delegate)
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
        )

    def _test_split_ethosu_BI_pipeline(
        self, compile_spec: CompileSpec, module: torch.nn.Module, test_data: test_data_t
    ):
        """Quantize ``module`` and lower it for an Ethos-U ``compile_spec``.

        The exported graph must contain ``aten.split.Tensor``, and
        partitioning must produce exactly one delegate call. The program is
        not executed and its outputs are not compared.
        """
        tester = ArmTester(
            module,
            example_inputs=test_data,
            compile_spec=compile_spec,
        )
        one_delegate = {"torch.ops.higher_order.executorch_call_delegate": 1}
        (
            tester.quantize()
            .export()
            .check(["torch.ops.aten.split.Tensor"])
            .to_edge()
            .partition()
            .check_count(one_delegate)
            .to_executorch()
        )

    @parameterized.expand(Split.test_data)
    def test_split_tosa_MI(self, test_data: test_data_t):
        self._test_split_tosa_MI_pipeline(self.Split(), test_data)

    # split_with_sizes only makes sense for the list-valued inputs (3 and 5).
    @parameterized.expand([Split.test_data[3], Split.test_data[5]])
    def test_split_with_sizes_tosa_MI(self, test_data: test_data_t):
        sections = test_data[1]
        assert isinstance(sections, list)
        self._test_split_tosa_MI_pipeline(self.SplitWithSizes(), test_data)

    @parameterized.expand(Split.test_data)
    def test_split_one_out_tosa_MI(self, test_data: test_data_t):
        self._test_split_tosa_MI_pipeline(self.SplitSingleOut(), test_data)

    @parameterized.expand(Split.test_data)
    def test_split_two_out_tosa_MI(self, test_data: test_data_t):
        self._test_split_tosa_MI_pipeline(self.SplitTwoOut(), test_data)

    @parameterized.expand(Split.test_data)
    def test_split_tosa_BI(self, test_data: test_data_t):
        self._test_split_tosa_BI_pipeline(self.Split(), test_data)

    # Integer chunk-size inputs only (entries 0, 1, 2 and 4).
    @parameterized.expand([*Split.test_data[:3], Split.test_data[4]])
    def test_split_u55_BI(self, test_data: test_data_t):
        u55_spec = common.get_u55_compile_spec()
        self._test_split_ethosu_BI_pipeline(u55_spec, self.Split(), test_data)

    # TODO MLETORCH-350
    # The list-valued inputs (3 and 5) are currently expected to fail on U55.
    @parameterized.expand([Split.test_data[3], Split.test_data[5]])
    @unittest.expectedFailure
    def test_split_u55_BI_skip(self, test_data: test_data_t):
        u55_spec = common.get_u55_compile_spec()
        self._test_split_ethosu_BI_pipeline(u55_spec, self.Split(), test_data)

    # Integer chunk-size inputs only (entries 0, 1, 2 and 4).
    @parameterized.expand([*Split.test_data[:3], Split.test_data[4]])
    def test_split_u85_BI(self, test_data: test_data_t):
        u85_spec = common.get_u85_compile_spec()
        self._test_split_ethosu_BI_pipeline(u85_spec, self.Split(), test_data)

    # TODO MLETORCH-350
    # The list-valued inputs (3 and 5) are currently expected to fail on U85.
    @parameterized.expand([Split.test_data[3], Split.test_data[5]])
    @unittest.expectedFailure
    def test_split_u85_BI_skip(self, test_data: test_data_t):
        u85_spec = common.get_u85_compile_spec()
        self._test_split_ethosu_BI_pipeline(u85_spec, self.Split(), test_data)