# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from typing import Tuple

import torch
from executorch.backends.arm.test import common

from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized


class TestCat(unittest.TestCase):
    """Tests lowering of ``torch.cat`` through the Arm TOSA MI/BI and Ethos-U flows."""

    class Cat(torch.nn.Module):
        """Module whose forward is a single ``torch.cat`` call."""

        # Each entry is ``(operands, dim)``; ``parameterized.expand`` unpacks it
        # into the ``operands`` / ``dim`` arguments of the test methods below.
        # NOTE(review): the random operands are drawn once, when this class body
        # executes, and no seed is set in this file, so the values differ between
        # runs. The ``10000 *`` case presumably stresses quantization with
        # operands of very different magnitudes -- confirm.
        test_parameters = [
            ((torch.ones(1), torch.ones(1)), 0),
            ((torch.ones(1, 2), torch.randn(1, 5), torch.randn(1, 1)), 1),
            (
                (
                    torch.ones(1, 2, 5),
                    torch.randn(1, 2, 4),
                    torch.randn(1, 2, 2),
                    torch.randn(1, 2, 1),
                ),
                -1,
            ),
            ((torch.randn(2, 2, 4, 4), torch.randn(2, 2, 4, 1)), 3),
            (
                (
                    10000 * torch.randn(2, 3, 1, 4),
                    torch.randn(2, 7, 1, 4),
                    torch.randn(2, 1, 1, 4),
                ),
                -3,
            ),
        ]

        def __init__(self):
            super().__init__()

        def forward(self, tensors: tuple[torch.Tensor, ...], dim: int) -> torch.Tensor:
            """Concatenate ``tensors`` along ``dim``."""
            return torch.cat(tensors, dim=dim)

    def _test_cat_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[tuple[torch.Tensor, ...], int]
    ):
        """Lower ``module`` with the TOSA-0.80.0+MI spec and compare outputs.

        Checks that export yields exactly one ``aten.cat`` and no quantized ops,
        that partitioning removes the edge ``cat`` op and leaves exactly one
        delegate call, then runs the ExecuTorch program on ``test_data`` and
        compares its outputs.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check_count({"torch.ops.aten.cat.default": 1})
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_cat_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_cat_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[tuple[torch.Tensor, ...], int]
    ):
        """Quantize and lower ``module`` with the TOSA-0.80.0+BI spec and compare outputs.

        Same checks as the MI pipeline, except that quantized ops are required to
        be present after export, and outputs are compared with ``qtol=1``.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.cat.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_cat_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data, qtol=1)
        )

    def _test_cat_ethosu_BI_pipeline(
        self,
        module: torch.nn.Module,
        compile_spec: CompileSpec,
        test_data: Tuple[tuple[torch.Tensor, ...], int],
    ):
        """Quantize and lower ``module`` for an Ethos-U target given by ``compile_spec``.

        Runs the same graph checks as the BI pipeline but stops after
        ``to_executorch()``; no output comparison is performed here.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.cat.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_cat_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    @parameterized.expand(Cat.test_parameters)
    def test_cat_tosa_MI(self, operands: tuple[torch.Tensor, ...], dim: int):
        """Run each ``Cat.test_parameters`` case through the TOSA MI pipeline."""
        test_data = (operands, dim)
        self._test_cat_tosa_MI_pipeline(self.Cat(), test_data)

    def test_cat_4d_tosa_MI(self):
        """Concatenate two 4-D tensors along every valid ``dim``, negative and positive."""
        square = torch.ones((2, 2, 2, 2))
        # Valid dims for a rank-R tensor are [-R, R - 1]; cover the whole range,
        # including both boundary values (-4 and 3 for rank 4).
        for dim in range(-square.dim(), square.dim()):
            # subTest makes a failure report which dim triggered it.
            with self.subTest(dim=dim):
                test_data = ((square, square), dim)
                self._test_cat_tosa_MI_pipeline(self.Cat(), test_data)

    @parameterized.expand(Cat.test_parameters)
    def test_cat_tosa_BI(self, operands: tuple[torch.Tensor, ...], dim: int):
        """Run each ``Cat.test_parameters`` case through the TOSA BI pipeline."""
        test_data = (operands, dim)
        self._test_cat_tosa_BI_pipeline(self.Cat(), test_data)

    @parameterized.expand(Cat.test_parameters)
    def test_cat_u55_BI(self, operands: tuple[torch.Tensor, ...], dim: int):
        """Lower each ``Cat.test_parameters`` case with the U55 compile spec."""
        test_data = (operands, dim)
        self._test_cat_ethosu_BI_pipeline(
            self.Cat(), common.get_u55_compile_spec(), test_data
        )

    @parameterized.expand(Cat.test_parameters)
    def test_cat_u85_BI(self, operands: tuple[torch.Tensor, ...], dim: int):
        """Lower each ``Cat.test_parameters`` case with the U85 compile spec."""
        test_data = (operands, dim)
        self._test_cat_ethosu_BI_pipeline(
            self.Cat(), common.get_u85_compile_spec(), test_data
        )