xref: /aosp_15_r20/external/executorch/backends/arm/test/ops/test_relu.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# Copyright 2024 Arm Limited and/or its affiliates.
3# All rights reserved.
4#
5# This source code is licensed under the BSD-style license found in the
6# LICENSE file in the root directory of this source tree.
7
import unittest

from typing import List, Tuple

import torch
from executorch.backends.arm.quantizer.arm_quantizer import (
    ArmQuantizer,
    get_symmetric_quantization_config,
)
from executorch.backends.arm.test import common
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.backends.xnnpack.test.tester.tester import Quantize
from executorch.exir.backend.backend_details import CompileSpec
from parameterized import parameterized
22
23
# Private, seeded generator for the random cases below. Drawing from the
# global RNG without a seed gave different inputs on every run, so a failure
# could not be reproduced; a dedicated generator also leaves the global RNG
# state untouched for other tests.
_test_data_generator = torch.Generator().manual_seed(0)

test_data_suite = [
    # (test_name, test_data)
    ("zeros", torch.zeros(1, 10, 10, 10)),
    ("ones", torch.ones(10, 10, 10)),
    # Uniform values in [-0.5, 0.5): mixes negative and positive inputs.
    ("rand", torch.rand(10, 10, generator=_test_data_generator) - 0.5),
    # Shifted far from zero so (almost) every value is positive / negative.
    ("randn_pos", torch.randn(10, generator=_test_data_generator) + 10),
    ("randn_neg", torch.randn(10, generator=_test_data_generator) - 10),
    ("ramp", torch.arange(-16, 16, 0.2)),
]
33
34
class TestRelu(unittest.TestCase):
    """Lowering tests for ``torch.nn.ReLU`` on the Arm backend.

    Each entry of ``test_data_suite`` is run through a non-quantized TOSA MI
    pipeline and a quantized TOSA BI pipeline (both executed and compared),
    and is compiled, but not executed, with the Ethos-U55 and Ethos-U85
    compile specs.
    """

    class Relu(torch.nn.Module):
        """Minimal module that applies a single ``torch.nn.ReLU``."""

        def __init__(self):
            super().__init__()
            self.relu = torch.nn.ReLU()

        def forward(self, x):
            return self.relu(x)

    def _test_relu_tosa_MI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Lower ``module`` for TOSA MI without quantization and run it.

        Checks that export yields exactly one ``aten.relu`` node and no
        ``quantized_decomposed`` ops, and that partitioning removes the edge
        ReLU into a single delegate call. It then compares the lowered
        program's outputs on ``test_data`` via
        ``run_method_and_compare_outputs``.

        Args:
            module: Module under test.
            test_data: Positional inputs for ``module.forward``.
        """
        (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=common.get_tosa_compile_spec("TOSA-0.80.0+MI"),
            )
            .export()
            # Pin exactly one ReLU, matching the quantized pipelines.
            .check_count({"torch.ops.aten.relu.default": 1})
            .check_not(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_relu_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
            .run_method_and_compare_outputs(inputs=test_data)
        )

    def _quantized_relu_pipeline(
        self,
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor],
        compile_spec: List[CompileSpec],
    ):
        """Quantize, export, partition and lower ``module`` for ``compile_spec``.

        Shared by the TOSA BI and Ethos-U BI tests so both use the same
        quantizer configuration and the same graph checks.

        Args:
            module: Module under test.
            test_data: Positional inputs for ``module.forward``; also used as
                the example inputs for export.
            compile_spec: Backend compile spec list handed to ``ArmTester``.

        Returns:
            The tester after ``to_executorch()``, so callers may additionally
            run and compare the lowered program.
        """
        quantizer = ArmQuantizer().set_io(get_symmetric_quantization_config())
        return (
            ArmTester(
                module,
                example_inputs=test_data,
                compile_spec=compile_spec,
            )
            .quantize(Quantize(quantizer, get_symmetric_quantization_config()))
            .export()
            .check_count({"torch.ops.aten.relu.default": 1})
            .check(["torch.ops.quantized_decomposed"])
            .to_edge()
            .partition()
            .check_not(["executorch_exir_dialects_edge__ops_aten_relu_default"])
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .to_executorch()
        )

    def _test_relu_tosa_BI_pipeline(
        self, module: torch.nn.Module, test_data: Tuple[torch.Tensor]
    ):
        """Lower ``module`` quantized for TOSA BI and compare its outputs."""
        self._quantized_relu_pipeline(
            module,
            test_data,
            common.get_tosa_compile_spec("TOSA-0.80.0+BI"),
        ).run_method_and_compare_outputs(inputs=test_data)

    def _test_relu_ethosu_BI_pipeline(
        self,
        compile_spec: List[CompileSpec],
        module: torch.nn.Module,
        test_data: Tuple[torch.Tensor],
    ):
        """Lower ``module`` quantized for an Ethos-U target (not executed)."""
        self._quantized_relu_pipeline(module, test_data, compile_spec)

    @parameterized.expand(test_data_suite)
    def test_relu_tosa_MI(self, test_name: str, test_data: torch.Tensor):
        # test_name only labels the generated test case.
        self._test_relu_tosa_MI_pipeline(self.Relu(), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_relu_tosa_BI(self, test_name: str, test_data: torch.Tensor):
        self._test_relu_tosa_BI_pipeline(self.Relu(), (test_data,))

    @parameterized.expand(test_data_suite)
    def test_relu_u55_BI(self, test_name: str, test_data: torch.Tensor):
        self._test_relu_ethosu_BI_pipeline(
            common.get_u55_compile_spec(), self.Relu(), (test_data,)
        )

    @parameterized.expand(test_data_suite)
    def test_relu_u85_BI(self, test_name: str, test_data: torch.Tensor):
        self._test_relu_ethosu_BI_pipeline(
            common.get_u85_compile_spec(), self.Relu(), (test_data,)
        )
133