xref: /aosp_15_r20/external/executorch/backends/arm/test/ops/test_tanh.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# Copyright 2024 Arm Limited and/or its affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
8import unittest
9
10from typing import Tuple
11
12import torch
13
14from executorch.backends.arm.test import common
15from executorch.backends.arm.test.tester.arm_tester import ArmTester
16from executorch.exir.backend.compile_spec_schema import CompileSpec
17from parameterized import parameterized
18
19
# Inputs shared by every parameterized tanh test in this module.
# Each entry is a (test_name, test_data) pair; the name becomes part of the
# generated test-case name and the tensor is the single model input.
test_data_suite = [
    # Constant tensors of different ranks.
    ("zeros", torch.zeros((10, 10, 10, 10))),
    ("ones", torch.ones((10, 10, 10))),
    # Uniform samples shifted so they straddle zero: values in [-0.5, 0.5).
    ("rand", torch.rand((10, 10)) - 0.5),
    # Normal samples centred on +10 / -10.
    ("randn_pos", torch.randn((10,)) + 10),
    ("randn_neg", torch.randn((10,)) - 10),
    # Evenly spaced sweep starting at -16 with step 0.2 (end 16 is exclusive).
    ("ramp", torch.arange(start=-16, end=16, step=0.2)),
]
29
30
class TestTanh(unittest.TestCase):
    """Tests that a single ``torch.nn.Tanh`` is lowered by the Arm backend.

    Every public ``test_*`` method is expanded by ``parameterized`` over
    ``test_data_suite`` and forwards to one of the private ``_test_*_pipeline``
    helpers, which drive an ``ArmTester`` through its export/partition stages.
    """

    class Tanh(torch.nn.Module):
        """Minimal module whose forward pass applies ``torch.nn.Tanh``."""

        def __init__(self):
            super().__init__()
            self.tanh = torch.nn.Tanh()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Return ``tanh(x)``."""
            return self.tanh(x)

    def _test_tanh_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ) -> None:
        """Run the non-quantized pipeline with the "TOSA-0.80.0+MI" spec.

        Args:
            module: Module under test.
            test_data: Tuple of example inputs; also used as the inputs for
                the final output comparison.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            # The exported graph must contain tanh and no quantization ops.
            .check(["torch.ops.aten.tanh.default"])
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            # After partitioning, tanh must no longer appear as an edge op and
            # exactly one delegate call is expected.
            .check_not(["executorch_exir_dialects_edge__ops_aten_tanh_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_tanh_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ) -> None:
        """Run the quantized pipeline with the "TOSA-0.80.0+BI" spec.

        Same stages as the MI pipeline, except that the module is quantized
        before export and quantization ops are required in the exported graph.

        Args:
            module: Module under test.
            test_data: Tuple of example inputs; also used as the inputs for
                the final output comparison.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
            )
            .quantize()
            .export()
            .check(["torch.ops.aten.tanh.default"])
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_tanh_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _test_tanh_tosa_ethos_BI_pipeline(
        self,
        compile_spec: list[CompileSpec],
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor],
    ) -> None:
        """Run the quantized pipeline for a caller-supplied compile spec.

        Unlike the TOSA MI/BI pipelines, this one stops after
        ``to_executorch()`` and does not compare outputs.

        Args:
            compile_spec: Compile spec handed to ``ArmTester`` unchanged.
            module: Module under test.
            test_data: Tuple of example inputs.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize()
            .export()
            # Exactly one tanh op is expected here (count check, not presence).
            .check_count({"torch.ops.aten.tanh.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_tanh_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def _test_tanh_tosa_u55_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ) -> None:
        """Run the Ethos pipeline with ``common.get_u55_compile_spec()``."""
        self._test_tanh_tosa_ethos_BI_pipeline(
            common.get_u55_compile_spec(), module, test_data
        )

    def _test_tanh_tosa_u85_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ) -> None:
        """Run the Ethos pipeline with ``common.get_u85_compile_spec()``."""
        self._test_tanh_tosa_ethos_BI_pipeline(
            common.get_u85_compile_spec(), module, test_data
        )

    # In the tests below ``test_name`` is supplied by ``parameterized.expand``
    # and is not used in the body; ``test_data`` is wrapped in a 1-tuple
    # because the pipelines take a tuple of inputs.

    @parameterized.expand(test_data_suite)
    def test_tanh_tosa_MI(
        self,
        test_name: str,
        test_data: torch.Tensor,
    ) -> None:
        """Run the TOSA MI pipeline on one suite entry."""
        self._test_tanh_tosa_MI_pipeline(self.Tanh(), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_tanh_tosa_BI(self, test_name: str, test_data: torch.Tensor) -> None:
        """Run the TOSA BI pipeline on one suite entry."""
        self._test_tanh_tosa_BI_pipeline(self.Tanh(), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_tanh_tosa_u55_BI(self, test_name: str, test_data: torch.Tensor) -> None:
        """Run the U55 compile-spec pipeline on one suite entry."""
        self._test_tanh_tosa_u55_BI_pipeline(self.Tanh(), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_tanh_tosa_u85_BI(self, test_name: str, test_data: torch.Tensor) -> None:
        """Run the U85 compile-spec pipeline on one suite entry."""
        self._test_tanh_tosa_u85_BI_pipeline(self.Tanh(), (test_data,))
135