# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

#
# Tests the view op which changes the size of a Tensor without changing the underlying data.
#

import unittest
from typing import Tuple

import torch

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester

from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized


class TestView(unittest.TestCase):
    """Tests the view operation."""

    class View(torch.nn.Module):
        """Module under test: applies ``Tensor.view`` with a caller-supplied shape.

        The two class-level lists hold ``(input_tensor, new_shape)`` pairs that
        are fed to ``parameterized.expand`` by the test methods of ``TestView``.
        """

        # Cases that the U55 test below marks as expected failures
        # (see ``test_view_transpose_u55_BI``).
        needs_transpose_tests = [
            (torch.rand(100), (1, -1, 5, 2)),
            (torch.rand(10, 2, 1, 5), (1, -1, 5, 2)),
            (torch.rand(1, 2, 1, 9), (3, 1, 3, 2)),
            (torch.rand(2, 1, 1, 9), (3, 2, 3, 1)),
            (torch.rand(2, 50, 2, 1), (1, 200)),
            (torch.rand(2, 5, 2, 3), (1, 15, 4)),
        ]

        # Cases that are run on every target, including U55.
        no_transpose_tests = [
            (torch.rand(2, 1, 1, 9), (3, 1, 3, 2)),
            (torch.rand(5, 10, 1, 1), (25, 2, 1, 1)),
            (torch.rand(10, 2), (1, 1, 5, 4)),
            (torch.rand(10, 10), (5, 1, 5, 4)),
            (torch.rand(1, 1, 1, 10), (1, 1, 10, 1)),
            (torch.rand(1, 1, 5, 10), (1, 1, 50, 1)),
            (torch.rand(5, 10, 1, 1), (1, 25, 2)),
            (torch.rand(2, 50, 1, 1), (1, 100)),
        ]

        def forward(self, x: torch.Tensor, new_shape):
            """Return ``x`` viewed with shape ``new_shape``."""
            return x.view(new_shape)

    def _test_view_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Run ``module`` through the ArmTester pipeline for "TOSA-0.80.0+MI".

        Stages: export, check_count for one ``aten.view.default``, to_edge,
        partition, check_count for one ``executorch_call_delegate``,
        to_executorch, then run_method_and_compare_outputs on ``test_data``.

        Args:
            module: The module to test.
            test_data: Example inputs for ``module``; the tests in this class
                pass ``(test_tensor, new_shape)``.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check_count({"torch.ops.aten.view.default": 1})
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_view_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Run ``module`` through the ArmTester pipeline for "TOSA-0.80.0+BI".

        Same stages as the MI pipeline, but with a ``quantize()`` stage before
        export and ``qtol=1`` passed to the final output comparison.

        Args:
            module: The module to test.
            test_data: Example inputs for ``module``; the tests in this class
                pass ``(test_tensor, new_shape)``.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.view.default": 1})
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
        )

    def _test_view_ethos_BI_pipeline(
        self,
        compile_spec: list[CompileSpec],
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor],
    ):
        """Run ``module`` through the quantized pipeline with ``compile_spec``.

        Shared by the U55 and U85 helpers. The chain stops after
        ``to_executorch()``; unlike the TOSA pipelines, no output comparison
        step is run here.

        Args:
            compile_spec: Compile spec passed straight to ``ArmTester``.
            module: The module to test.
            test_data: Example inputs for ``module``; the tests in this class
                pass ``(test_tensor, new_shape)``.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.view.default": 1})
            .to_edge()
            .partition()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def _test_view_u55_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Run the Ethos BI pipeline with ``common.get_u55_compile_spec()``."""
        self._test_view_ethos_BI_pipeline(
            common.get_u55_compile_spec(), module, test_data
        )

    def _test_view_u85_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Run the Ethos BI pipeline with ``common.get_u85_compile_spec()``."""
        self._test_view_ethos_BI_pipeline(
            common.get_u85_compile_spec(), module, test_data
        )

    @parameterized.expand(View.needs_transpose_tests + View.no_transpose_tests)
    def test_view_tosa_MI(self, test_tensor: torch.Tensor, new_shape):
        """TOSA MI pipeline over every test case."""
        self._test_view_tosa_MI_pipeline(self.View(), (test_tensor, new_shape))

    @parameterized.expand(View.needs_transpose_tests + View.no_transpose_tests)
    def test_view_tosa_BI(self, test_tensor: torch.Tensor, new_shape):
        """TOSA BI pipeline over every test case."""
        self._test_view_tosa_BI_pipeline(self.View(), (test_tensor, new_shape))

    @parameterized.expand(View.no_transpose_tests)
    def test_view_u55_BI(self, test_tensor: torch.Tensor, new_shape):
        """U55 BI pipeline over the ``no_transpose_tests`` cases only."""
        self._test_view_u55_BI_pipeline(self.View(), (test_tensor, new_shape))

    @parameterized.expand(View.needs_transpose_tests)
    @unittest.expectedFailure
    def test_view_transpose_u55_BI(self, test_tensor: torch.Tensor, new_shape):
        """U55 BI pipeline over ``needs_transpose_tests``; marked expected-failure."""
        self._test_view_u55_BI_pipeline(self.View(), (test_tensor, new_shape))

    @parameterized.expand(View.needs_transpose_tests + View.no_transpose_tests)
    def test_view_u85_BI(self, test_tensor: torch.Tensor, new_shape):
        """U85 BI pipeline over every test case."""
        self._test_view_u85_BI_pipeline(self.View(), (test_tensor, new_shape))