xref: /aosp_15_r20/external/executorch/backends/arm/test/ops/test_cat.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# Copyright 2024 Arm Limited and/or its affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8import unittest
9
10from typing import Tuple
11
12import torch
13from executorch.backends.arm.test import common
14
15from executorch.backends.arm.test.tester.arm_tester import ArmTester
16from executorch.exir.backend.compile_spec_schema import CompileSpec
17from parameterized import parameterized
18
19
class TestCat(unittest.TestCase):
    """Exercise ``torch.cat`` through the Arm backend test pipelines.

    Each ``_test_cat_*_pipeline`` helper builds an ``ArmTester`` for a given
    compile spec and walks it through export, edge conversion, partitioning
    and ExecuTorch lowering, asserting along the way that the ``cat`` op ends
    up inside a single delegate call.
    """

    class Cat(torch.nn.Module):
        """Minimal module whose forward pass is a single ``torch.cat`` call."""

        # Each entry is ``(tensors, dim)``. The cases cover 1-D to 4-D
        # operands, positive and negative ``dim`` values, operands whose sizes
        # differ along the concatenation axis, and one operand scaled by
        # 10000 so the inputs have very different magnitudes.
        test_parameters = [
            ((torch.ones(1), torch.ones(1)), 0),
            ((torch.ones(1, 2), torch.randn(1, 5), torch.randn(1, 1)), 1),
            (
                (
                    torch.ones(1, 2, 5),
                    torch.randn(1, 2, 4),
                    torch.randn(1, 2, 2),
                    torch.randn(1, 2, 1),
                ),
                -1,
            ),
            ((torch.randn(2, 2, 4, 4), torch.randn(2, 2, 4, 1)), 3),
            (
                (
                    10000 * torch.randn(2, 3, 1, 4),
                    torch.randn(2, 7, 1, 4),
                    torch.randn(2, 1, 1, 4),
                ),
                -3,
            ),
        ]

        def forward(self, tensors: tuple[torch.Tensor, ...], dim: int) -> torch.Tensor:
            """Concatenate ``tensors`` along dimension ``dim``."""
            return torch.cat(tensors, dim=dim)

    def _test_cat_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[tuple[torch.Tensor, ...], int]
    ):
        """Run ``module`` through the ``TOSA-0.80.0+MI`` pipeline.

        No quantization step is applied. After export the graph is expected to
        hold exactly one ``aten.cat`` and no ``quantized_decomposed`` ops;
        after partitioning the edge-dialect ``cat`` must be gone and exactly
        one delegate call must remain. Finally the lowered program is run and
        its outputs are compared using ``test_data`` as inputs.

        Args:
            module: Module under test.
            test_data: ``(tensors, dim)`` used both as example inputs and as
                the inputs for the output comparison.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check_count({"torch.ops.aten.cat.default": 1})
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_cat_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_cat_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[tuple[torch.Tensor, ...], int]
    ):
        """Run ``module`` through the quantized ``TOSA-0.80.0+BI`` pipeline.

        Same stages as the MI pipeline, except that the module is quantized
        before export, ``quantized_decomposed`` ops are required to be present
        after export, and the output comparison is called with ``qtol=1``.

        Args:
            module: Module under test.
            test_data: ``(tensors, dim)`` used both as example inputs and as
                the inputs for the output comparison.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.cat.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_cat_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
        )

    def _test_cat_ethosu_BI_pipeline(
        self,
        module: torch.nn.Module,
        compile_spec: CompileSpec,
        test_data: Tuple[tuple[torch.Tensor, ...], int],
    ):
        """Lower ``module`` with a caller-supplied compile spec.

        Performs the same quantize/export/partition checks as the TOSA BI
        pipeline but stops after ``to_executorch()``: the lowered program is
        not executed and no outputs are compared here.

        Args:
            module: Module under test.
            compile_spec: Compile spec forwarded unchanged to ``ArmTester``.
            test_data: ``(tensors, dim)`` used as example inputs.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.cat.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_cat_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    @parameterized.expand(Cat.test_parameters)
    def test_cat_tosa_MI(self, operands: tuple[torch.Tensor, ...], dim: int):
        """Run every ``Cat.test_parameters`` case through the TOSA MI pipeline."""
        test_data = (operands, dim)
        self._test_cat_tosa_MI_pipeline(self.Cat(), test_data)

    def test_cat_4d_tosa_MI(self):
        """Concatenate two identical 4-D tensors along a sweep of ``dim`` values."""
        square = torch.ones((2, 2, 2, 2))
        # NOTE(review): range(-3, 3) yields -3..2, so dim=-4 and dim=3 are not
        # exercised by this test on the 4-D input - confirm this is intended.
        for dim in range(-3, 3):
            # subTest reports the failing ``dim`` and, where the runner
            # supports subtests, keeps going so one bad dimension does not
            # hide the results for the others.
            with self.subTest(dim=dim):
                test_data = ((square, square), dim)
                self._test_cat_tosa_MI_pipeline(self.Cat(), test_data)

    @parameterized.expand(Cat.test_parameters)
    def test_cat_tosa_BI(self, operands: tuple[torch.Tensor, ...], dim: int):
        """Run every ``Cat.test_parameters`` case through the TOSA BI pipeline."""
        test_data = (operands, dim)
        self._test_cat_tosa_BI_pipeline(self.Cat(), test_data)

    @parameterized.expand(Cat.test_parameters)
    def test_cat_u55_BI(self, operands: tuple[torch.Tensor, ...], dim: int):
        """Lower every ``Cat.test_parameters`` case with the U55 compile spec."""
        test_data = (operands, dim)
        self._test_cat_ethosu_BI_pipeline(
            self.Cat(), common.get_u55_compile_spec(), test_data
        )

    @parameterized.expand(Cat.test_parameters)
    def test_cat_u85_BI(self, operands: tuple[torch.Tensor, ...], dim: int):
        """Lower every ``Cat.test_parameters`` case with the U85 compile spec."""
        test_data = (operands, dim)
        self._test_cat_ethosu_BI_pipeline(
            self.Cat(), common.get_u85_compile_spec(), test_data
        )
145