xref: /aosp_15_r20/external/executorch/backends/arm/test/ops/test_split.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import unittest
8
9import torch
10
11from executorch.backends.arm.test import common
12from executorch.backends.arm.test.tester.arm_tester import ArmTester
13from executorch.exir.backend.compile_spec_schema import CompileSpec
14from parameterized import parameterized
15
# One split test case: (input tensor, chunk size or list of section sizes, dim to split along).
test_data_t = tuple[torch.Tensor, int | list[int], int]
17
18
class TestSimpleSplit(unittest.TestCase):
    """Tests lowering of ``Tensor.split`` / ``Tensor.split_with_sizes`` to the Arm backend.

    Every case is exported, partitioned for a TOSA (MI or BI) or Ethos-U
    (U55/U85) compile spec, and required to produce exactly one delegate call.
    The TOSA pipelines additionally execute the program and compare outputs.
    """

    class Split(torch.nn.Module):
        """Return all chunks of ``x.split(split_size_or_sections, dim)``."""

        # Each entry is wrapped in a 1-tuple so parameterized.expand passes the
        # whole (input, split_size_or_sections, dim) triple as a single argument.
        test_data: list[tuple[test_data_t]] = [
            ((torch.rand(10), 2, 0),),
            ((torch.rand(10, 10), 3, 1),),
            ((torch.rand(10, 10), 4, -1),),
            ((torch.rand(10, 15, 10), [2, 2, 11], 1),),
            ((torch.rand(4, 4, 4, 4), 2, 0),),
            ((torch.rand(4, 4, 4, 4), [1, 1, 1, 1], -2),),
        ]

        def forward(
            self, x: torch.Tensor, split_size_or_sections: int | list[int], dim: int
        ):
            return x.split(split_size=split_size_or_sections, dim=dim)

    class SplitWithSizes(torch.nn.Module):
        """Return all chunks of ``x.split_with_sizes(split_sizes, dim)``."""

        def forward(self, x: torch.Tensor, split_sizes: list[int], dim: int):
            return x.split_with_sizes(split_sizes=split_sizes, dim=dim)

    class SplitSingleOut(torch.nn.Module):
        """Return only the second chunk of ``x.split(split_size_or_sections, dim)``."""

        def forward(
            self, x: torch.Tensor, split_size_or_sections: int | list[int], dim: int
        ):
            return x.split(split_size=split_size_or_sections, dim=dim)[1]

    class SplitTwoOut(torch.nn.Module):
        """Return the second and third chunks of ``x.split(split_size_or_sections, dim)``.

        When the split yields only two chunks (e.g. the ``(4, 4, 4, 4)`` input
        split into size-2 pieces along dim 0), the slice holds a single tensor.
        """

        def forward(
            self, x: torch.Tensor, split_size_or_sections: int | list[int], dim: int
        ):
            return x.split(split_size=split_size_or_sections, dim=dim)[1:3]

    # Cases whose split argument is a single chunk size (int). Derived from the
    # data rather than hard-coded indices so the grouping stays correct when
    # cases are added; order (and thus generated test names) is preserved.
    _int_split_cases = [
        case for case in Split.test_data if isinstance(case[0][1], int)
    ]
    # Cases whose split argument is an explicit list of section sizes.
    _list_split_cases = [
        case for case in Split.test_data if isinstance(case[0][1], list)
    ]

    def _test_split_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: test_data_t
    ):
        """Lower ``module`` for TOSA MI, then run it and compare outputs.

        The edge-dialect graph must contain ``split_with_sizes_copy`` and the
        partitioner must produce exactly one delegate call.

        Args:
            module: The module under test.
            test_data: ``(input, split_size_or_sections, dim)`` passed as the
                example inputs and as the runtime inputs.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .to_edge()
            .check(
                [
                    "executorch_exir_dialects_edge__ops_aten_split_with_sizes_copy_default"
                ]
            )
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_split_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: test_data_t
    ):
        """Quantize and lower ``module`` for TOSA BI, then run it and compare outputs.

        Args:
            module: The module under test.
            test_data: ``(input, split_size_or_sections, dim)`` passed as the
                example inputs and as the runtime inputs.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            # NOTE(review): qtol=1 presumably tolerates a one-step quantization
            # mismatch against the reference output — confirm in ArmTester.
            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
        )

    def _test_split_ethosu_BI_pipeline(
        self,
        compile_spec: list[CompileSpec],
        module: torch.nn.Module,
        test_data: test_data_t,
    ):
        """Quantize and lower ``module`` for an Ethos-U target.

        Only lowering is verified: the exported graph must contain
        ``aten.split.Tensor`` and partitioning must yield exactly one delegate
        call. The program is not executed, so outputs are not compared.

        Args:
            compile_spec: Target compile specs, e.g. the list returned by
                ``common.get_u55_compile_spec()`` or ``common.get_u85_compile_spec()``.
            module: The module under test.
            test_data: ``(input, split_size_or_sections, dim)`` example inputs.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check(["torch.ops.aten.split.Tensor"])
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    @parameterized.expand(Split.test_data)
    def test_split_tosa_MI(self, test_data: test_data_t):
        self._test_split_tosa_MI_pipeline(self.Split(), test_data)

    # split_with_sizes only accepts a list of sizes, so use the list cases only.
    @parameterized.expand(_list_split_cases)
    def test_split_with_sizes_tosa_MI(self, test_data: test_data_t):
        assert isinstance(test_data[1], list)
        self._test_split_tosa_MI_pipeline(self.SplitWithSizes(), test_data)

    @parameterized.expand(Split.test_data)
    def test_split_one_out_tosa_MI(self, test_data: test_data_t):
        self._test_split_tosa_MI_pipeline(self.SplitSingleOut(), test_data)

    @parameterized.expand(Split.test_data)
    def test_split_two_out_tosa_MI(self, test_data: test_data_t):
        self._test_split_tosa_MI_pipeline(self.SplitTwoOut(), test_data)

    @parameterized.expand(Split.test_data)
    def test_split_tosa_BI(self, test_data: test_data_t):
        self._test_split_tosa_BI_pipeline(self.Split(), test_data)

    @parameterized.expand(_int_split_cases)
    def test_split_u55_BI(self, test_data: test_data_t):
        self._test_split_ethosu_BI_pipeline(
            common.get_u55_compile_spec(), self.Split(), test_data
        )

    # TODO MLETORCH-350
    # NOTE(review): with a list argument Tensor.split dispatches to
    # aten.split_with_sizes, so the aten.split.Tensor check in the Ethos-U
    # pipeline may be what fails here rather than lowering itself — confirm.
    @parameterized.expand(_list_split_cases)
    @unittest.expectedFailure
    def test_split_u55_BI_skip(self, test_data: test_data_t):
        self._test_split_ethosu_BI_pipeline(
            common.get_u55_compile_spec(), self.Split(), test_data
        )

    @parameterized.expand(_int_split_cases)
    def test_split_u85_BI(self, test_data: test_data_t):
        self._test_split_ethosu_BI_pipeline(
            common.get_u85_compile_spec(), self.Split(), test_data
        )

    # List-sized cases are expected to fail on U85 as well, mirroring the U55
    # expected failures above.
    @parameterized.expand(_list_split_cases)
    @unittest.expectedFailure
    def test_split_u85_BI_skip(self, test_data: test_data_t):
        self._test_split_ethosu_BI_pipeline(
            common.get_u85_compile_spec(), self.Split(), test_data
        )
162