# xref: /aosp_15_r20/external/executorch/backends/xnnpack/test/ops/hardtanh.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

7import unittest
8
9import torch
10from executorch.backends.xnnpack.test.tester import Tester
11
12
class TestHardTanh(unittest.TestCase):
    """XNNPACK delegation tests for ``aten.hardtanh``.

    Every test builds a ``Tester`` pipeline for several input shapes:
    export, lower with ``to_edge_transform_and_lower``, assert that exactly
    one delegate call is produced and that the edge-dialect hardtanh op no
    longer appears in the lowered graph, then serialize the program and run
    it via ``run_method_and_compare_outputs``.
    """

    # Graph node names checked after lowering; shared by every test below.
    _DELEGATE_CALL = "torch.ops.higher_order.executorch_call_delegate"
    _EDGE_HARDTANH = "executorch_exir_dialects_edge__ops_aten_hardtanh_default"

    class HardTanh(torch.nn.Module):
        """Compute ``hardtanh(x + x, min_val, max_val)``.

        The ``x + x`` places an add in front of the hardtanh; the qs8 test's
        quantize-count expectation includes a quantize for this add.
        """

        def __init__(self, min_val=-1.0, max_val=1.0):
            super().__init__()
            self.min_val = min_val
            self.max_val = max_val

        def forward(self, x):
            doubled = x + x
            # Call the functional op directly rather than constructing a new
            # torch.nn.Hardtanh module on every forward. nn.Hardtanh.forward
            # is a thin wrapper over this function, so the exported graph
            # still contains a single aten.hardtanh.default node.
            return torch.nn.functional.hardtanh(doubled, self.min_val, self.max_val)

    def _check_fp32_lowering(self, min_val=-1.0, max_val=1.0):
        """Lower an fp32 ``HardTanh(min_val, max_val)`` and check delegation.

        Runs the export -> lower -> serialize -> run pipeline once per input
        shape. Each shape runs inside its own ``subTest``, so a failure names
        the offending shape and the remaining shapes still run.
        """
        sample_inputs = [torch.randn(2, 3, 4), torch.randn(7, 5, 2), torch.randn(2, 9)]
        for sample in sample_inputs:
            with self.subTest(shape=tuple(sample.shape)):
                (
                    Tester(self.HardTanh(min_val, max_val), (sample,))
                    .export()
                    .check_count({"torch.ops.aten.hardtanh.default": 1})
                    .to_edge_transform_and_lower()
                    .check_count({self._DELEGATE_CALL: 1})
                    .check_not([self._EDGE_HARDTANH])
                    .to_executorch()
                    .serialize()
                    .run_method_and_compare_outputs()
                )

    def test_fp32_hardtanh(self):
        """fp32 hardtanh with the default bounds ``[-1, 1]``."""
        self._check_fp32_lowering()

    def test_fp32_hardtanh_bound(self):
        """fp32 hardtanh with non-default bounds ``[-2, 2]``."""
        self._check_fp32_lowering(-2.0, 2.0)

    def test_qs8_hardtanh(self):
        """Quantized (qs8) hardtanh is lowered into a single delegate call.

        After ``quantize()`` + ``export()`` three quantize nodes are expected
        (input, add, hardtanh). After lowering, neither the edge hardtanh op
        nor any ``quantized_decomposed`` op may appear in the lowered graph.
        """
        sample_inputs = [torch.randn(2, 3, 2), torch.randn(2, 1, 2), torch.randn(2, 3)]
        for sample in sample_inputs:
            with self.subTest(shape=tuple(sample.shape)):
                (
                    Tester(self.HardTanh(), (sample,))
                    .quantize()
                    .export()
                    .check_node_count(
                        {
                            # Expect three quantize ops - one for input, hardtanh, and add.
                            torch.ops.quantized_decomposed.quantize_per_tensor.default: 3,
                            torch.ops.aten.hardtanh.default: 1,
                        }
                    )
                    .to_edge_transform_and_lower()
                    .check_count({self._DELEGATE_CALL: 1})
                    .check_not(
                        [
                            self._EDGE_HARDTANH,
                            "torch.ops.quantized_decomposed",
                        ]
                    )
                    .to_executorch()
                    .serialize()
                    .run_method_and_compare_outputs()
                )
81