# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from typing import Tuple

import torch

from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.exir.backend.compile_spec_schema import CompileSpec
from parameterized import parameterized


# Inputs shared by every parameterized test below; parameterized.expand
# generates one test case per entry for each decorated method.
test_data_suite = [
    # (test_name, test_data)
    ("zeros", torch.zeros(10, 10, 10, 10)),
    ("ones", torch.ones(10, 10, 10)),
    ("rand", torch.rand(10, 10) - 0.5),
    ("randn_pos", torch.randn(10) + 10),
    ("randn_neg", torch.randn(10) - 10),
    ("ramp", torch.arange(-16, 16, 0.2)),
]


class TestTanh(unittest.TestCase):
    """Tests lowering of ``torch.nn.Tanh`` through the Arm backend.

    Each ``_test_*_pipeline`` helper drives an ``ArmTester`` through export,
    edge lowering, partitioning and ExecuTorch program generation for one
    compile spec, checking the graph at each stage.
    """

    class Tanh(torch.nn.Module):
        """Minimal module that applies a single ``torch.nn.Tanh``."""

        def __init__(self):
            super().__init__()
            self.tanh = torch.nn.Tanh()

        def forward(self, x):
            return self.tanh(x)

    def _test_tanh_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Lower ``module`` with the TOSA-0.80.0+MI spec, without quantization.

        Checks the following:

        * ``aten.tanh`` is exported and no quantize/dequantize ops are present.
        * After partitioning, no edge-dialect tanh op remains and exactly one
          delegate call exists.
        * Running the ExecuTorch program on ``test_data`` reproduces the
          reference outputs.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            .check(["torch.ops.aten.tanh.default"])
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            # The tanh must have been absorbed into the delegate.
            .check_not(["executorch_exir_dialects_edge__ops_aten_tanh_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_tanh_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Quantize ``module`` and lower it with the TOSA-0.80.0+BI spec.

        This follows the same flow as the MI pipeline, with two differences:

        * ``.quantize()`` runs before export.
        * The exported graph is expected to contain quantized_decomposed ops.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .check(["torch.ops.aten.tanh.default"])
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_tanh_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_tanh_tosa_ethos_BI_pipeline(
        self,
        compile_spec: list[CompileSpec],
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor],
    ):
        """Quantize and lower ``module`` for an Ethos-U target.

        Args:
            compile_spec: Target compile spec, e.g. from
                ``common.get_u55_compile_spec()`` or
                ``common.get_u85_compile_spec()``.
            module: Module under test.
            test_data: Example inputs used for export and quantization.

        Checks:
            The exported graph contains exactly one ``aten.tanh`` op.

        Note:
            Unlike the TOSA helpers, this pipeline stops after
            ``to_executorch()`` and does not run the program. Presumably that
            is because execution needs an Ethos-U target or simulator (TODO
            confirm).
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            .check_count({"torch.ops.aten.tanh.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_tanh_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def _test_tanh_tosa_u55_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Run the Ethos-U BI pipeline with the U55 compile spec."""
        self._test_tanh_tosa_ethos_BI_pipeline(
            common.get_u55_compile_spec(), module, test_data
        )

    def _test_tanh_tosa_u85_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Run the Ethos-U BI pipeline with the U85 compile spec."""
        self._test_tanh_tosa_ethos_BI_pipeline(
            common.get_u85_compile_spec(), module, test_data
        )

    # In the tests below, ``test_name`` is unused in the body. It is accepted
    # because parameterized.expand passes every element of each suite entry
    # positionally.

    @parameterized.expand(test_data_suite)
    def test_tanh_tosa_MI(self, test_name: str, test_data: torch.Tensor):
        self._test_tanh_tosa_MI_pipeline(self.Tanh(), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_tanh_tosa_BI(self, test_name: str, test_data: torch.Tensor):
        self._test_tanh_tosa_BI_pipeline(self.Tanh(), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_tanh_tosa_u55_BI(self, test_name: str, test_data: torch.Tensor):
        self._test_tanh_tosa_u55_BI_pipeline(self.Tanh(), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_tanh_tosa_u85_BI(self, test_name: str, test_data: torch.Tensor):
        self._test_tanh_tosa_u85_BI_pipeline(self.Tanh(), (test_data,))