xref: /aosp_15_r20/external/executorch/backends/arm/operators/op_hardtanh.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2023-2024 Arm Limited and/or its affiliates.
2#
3# This source code is licensed under the BSD-style license found in the
4# LICENSE file in the root directory of this source tree.
5
6# pyre-unsafe
7from typing import List
8
9import serializer.tosa_serializer as ts
10import torch
11from executorch.backends.arm.operators.node_visitor import (
12    NodeVisitor,
13    register_node_visitor,
14)
15from executorch.backends.arm.tosa_mapping import TosaArg
16
17from executorch.backends.arm.tosa_quant_utils import (
18    get_quant_arg_upstream,
19    quantize_value,
20)
21from serializer.tosa_serializer import TosaOp
22
23
@register_node_visitor
class HardTanhVisitor(NodeVisitor):
    """Lower ``aten.hardtanh.default`` to a single TOSA ``CLAMP`` operator."""

    target = "aten.hardtanh.default"

    # Defaults taken from the ATen schema:
    #   hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor
    # Used when a node does not carry the bound explicitly.
    _DEFAULT_MIN_VAL = -1.0
    _DEFAULT_MAX_VAL = 1.0

    def __init__(self, *args):
        super().__init__(*args)

    @staticmethod
    def _resolve_bound(node, inputs, index, kwarg_name, default):
        """Return one clamp bound: positional input, then node kwarg, then default."""
        # NOTE(review): assumes `inputs` mirrors `node.args` positionally, so an
        # omitted trailing argument simply shortens the list — confirm in caller.
        if len(inputs) > index:
            return inputs[index].number
        # Without this lookup, a bound passed as a kwarg would be silently
        # replaced by the schema default, producing a wrong clamp range.
        if kwarg_name in node.kwargs:
            return node.kwargs[kwarg_name]
        return default

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Emit a TOSA CLAMP that implements hardtanh.

        Args:
            node: The FX node being lowered. Its ``kwargs`` are consulted for
                ``min_val``/``max_val`` when they are not positional inputs.
                In the quantized case, its first input node supplies the
                quantization parameters.
            tosa_graph: Serializer the CLAMP operator is appended to.
            inputs: Converted node arguments. ``inputs[0]`` is the tensor to
                clamp; ``inputs[1]`` and ``inputs[2]``, when present, hold
                ``min_val`` and ``max_val``.
            output: Destination of the CLAMP.
            is_quant_node: If True, the bounds are quantized and stored in the
                integer fields of the ClampAttribute. Otherwise they are stored
                in the float fields.
        """
        attr = ts.TosaSerializerAttribute()

        min_val = self._resolve_bound(
            node, inputs, 1, "min_val", self._DEFAULT_MIN_VAL
        )
        max_val = self._resolve_bound(
            node, inputs, 2, "max_val", self._DEFAULT_MAX_VAL
        )

        if is_quant_node:
            # Quantize the bounds with the producer's qparams so the clamp
            # operates directly on the input's quantized values.
            qargs = get_quant_arg_upstream(node.all_input_nodes[0])
            clamp_min_qs = quantize_value(min_val, qargs)
            clamp_max_qs = quantize_value(max_val, qargs)
            # Float bounds are unused for a quantized clamp.
            clamp_min_fp = 0.0
            clamp_max_fp = 0.0
        else:
            clamp_min_fp = min_val
            clamp_max_fp = max_val
            # Integer bounds are unused for a float clamp.
            clamp_min_qs = 0
            clamp_max_qs = 0

        attr.ClampAttribute(
            tosa_graph.builder,
            clamp_min_qs,
            clamp_max_qs,
            clamp_min_fp,
            clamp_max_fp,
        )

        tosa_graph.addOperator(TosaOp.Op().CLAMP, [inputs[0].name], [output.name], attr)
66