# xref: /aosp_15_r20/external/executorch/backends/arm/test/ops/test_view.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

#
# Tests the view op, which changes the shape of a tensor without changing its underlying data.
#

11import unittest
12from typing import Tuple
13
14import torch
15
16from executorch.backends.arm.test import common
17from executorch.backends.arm.test.tester.arm_tester import ArmTester
18
19from executorch.exir.backend.compile_spec_schema import CompileSpec
20from parameterized import parameterized
21
22
class TestView(unittest.TestCase):
    """Tests the view operation (``Tensor.view``) on the Arm backend.

    Every pipeline:
    1. exports a single-op module;
    2. checks that the exported graph has exactly one ``aten.view`` node;
    3. partitions it and checks for exactly one delegate call;
    4. lowers the result to ExecuTorch.

    The TOSA pipelines additionally run the lowered method and compare its
    outputs. The Ethos-U pipelines stop after lowering.
    """

    class View(torch.nn.Module):
        """Module that returns ``x.view(new_shape)``."""

        # (input tensor, target shape) pairs. Judging by the list name, these
        # reshapes need a transpose when lowered. The Ethos-U55 test over this
        # list is marked as an expected failure.
        needs_transpose_tests = [
            (torch.rand(100), (1, -1, 5, 2)),
            (torch.rand(10, 2, 1, 5), (1, -1, 5, 2)),
            (torch.rand(1, 2, 1, 9), (3, 1, 3, 2)),
            (torch.rand(2, 1, 1, 9), (3, 2, 3, 1)),
            (torch.rand(2, 50, 2, 1), (1, 200)),
            (torch.rand(2, 5, 2, 3), (1, 15, 4)),
        ]

        # (input tensor, target shape) pairs exercised on every target,
        # including Ethos-U55.
        no_transpose_tests = [
            (torch.rand(2, 1, 1, 9), (3, 1, 3, 2)),
            (torch.rand(5, 10, 1, 1), (25, 2, 1, 1)),
            (torch.rand(10, 2), (1, 1, 5, 4)),
            (torch.rand(10, 10), (5, 1, 5, 4)),
            (torch.rand(1, 1, 1, 10), (1, 1, 10, 1)),
            (torch.rand(1, 1, 5, 10), (1, 1, 50, 1)),
            (torch.rand(5, 10, 1, 1), (1, 25, 2)),
            (torch.rand(2, 50, 1, 1), (1, 100)),
        ]

        def forward(
            self, x: torch.Tensor, new_shape: Tuple[int, ...]
        ) -> torch.Tensor:
            return x.view(new_shape)

    def _test_view_tosa_MI_pipeline(
        self,
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor, Tuple[int, ...]],
    ):
        """Lower *module* unquantized for TOSA-0.80.0+MI and compare outputs.

        Args:
            module: the module under test.
            test_data: ``(input_tensor, new_shape)``. Used both as the
                example inputs for export and as the inputs for the output
                comparison.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check_count({"torch.ops.aten.view.default": 1})
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _lower_view_BI(
        self,
        compile_spec: list[CompileSpec],
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor, Tuple[int, ...]],
    ):
        """Quantize and lower *module* with *compile_spec*, returning the tester.

        Shared by the TOSA BI and Ethos-U pipelines so the stage sequence and
        the op/delegate count checks are defined once. The returned object is
        whatever ``to_executorch()`` returns, so callers can keep chaining
        (e.g. ``run_method_and_compare_outputs``).

        Args:
            compile_spec: backend compile spec to lower with.
            module: the module under test.
            test_data: ``(input_tensor, new_shape)`` example inputs.
        """
        return (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.view.default": 1})
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def _test_view_tosa_BI_pipeline(
        self,
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor, Tuple[int, ...]],
    ):
        """Quantize and lower *module* for TOSA-0.80.0+BI, then compare outputs.

        ``qtol=1`` is forwarded to the output comparison as the quantized
        tolerance. NOTE(review): its exact unit is defined by ArmTester; confirm
        there.
        """
        self._lower_view_BI(
            common.get_tosa_compile_spec("TOSA-0.80.0+BI"), module, test_data
        ).run_method_and_compare_outputs(inputs=test_data, qtol=1)

    def _test_view_ethos_BI_pipeline(
        self,
        compile_spec: list[CompileSpec],
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor, Tuple[int, ...]],
    ):
        """Quantize and lower *module* for an Ethos-U target.

        The lowered program is not executed, so only the export/partition
        checks are exercised.
        """
        self._lower_view_BI(compile_spec, module, test_data)

    def _test_view_u55_BI_pipeline(
        self,
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor, Tuple[int, ...]],
    ):
        """Run the Ethos-U pipeline with the U55 compile spec."""
        self._test_view_ethos_BI_pipeline(
            common.get_u55_compile_spec(), module, test_data
        )

    def _test_view_u85_BI_pipeline(
        self,
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor, Tuple[int, ...]],
    ):
        """Run the Ethos-U pipeline with the U85 compile spec."""
        self._test_view_ethos_BI_pipeline(
            common.get_u85_compile_spec(), module, test_data
        )

    # TOSA targets are exercised with every shape.
    @parameterized.expand(View.needs_transpose_tests + View.no_transpose_tests)
    def test_view_tosa_MI(self, test_tensor: torch.Tensor, new_shape):
        self._test_view_tosa_MI_pipeline(self.View(), (test_tensor, new_shape))

    @parameterized.expand(View.needs_transpose_tests + View.no_transpose_tests)
    def test_view_tosa_BI(self, test_tensor: torch.Tensor, new_shape):
        self._test_view_tosa_BI_pipeline(self.View(), (test_tensor, new_shape))

    @parameterized.expand(View.no_transpose_tests)
    def test_view_u55_BI(self, test_tensor: torch.Tensor, new_shape):
        self._test_view_u55_BI_pipeline(self.View(), (test_tensor, new_shape))

    # The "needs transpose" shapes are expected to fail on Ethos-U55.
    # NOTE(review): presumably because the U55 lowering cannot handle the
    # transpose these reshapes require; confirm before removing the marker.
    @parameterized.expand(View.needs_transpose_tests)
    @unittest.expectedFailure
    def test_view_transpose_u55_BI(self, test_tensor: torch.Tensor, new_shape):
        self._test_view_u55_BI_pipeline(self.View(), (test_tensor, new_shape))

    # Ethos-U85 is exercised with every shape.
    @parameterized.expand(View.needs_transpose_tests + View.no_transpose_tests)
    def test_view_u85_BI(self, test_tensor: torch.Tensor, new_shape):
        self._test_view_u85_BI_pipeline(self.View(), (test_tensor, new_shape))
143