# Copyright 2023-2024 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
from typing import List

import serializer.tosa_serializer as ts
import torch
from executorch.backends.arm.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.arm.tosa_mapping import TosaArg

from executorch.backends.arm.tosa_quant_utils import (
    get_quant_arg_upstream,
    quantize_value,
)
from serializer.tosa_serializer import TosaOp


@register_node_visitor
class HardTanhVisitor(NodeVisitor):
    """Lower ``aten.hardtanh.default`` to a single TOSA ``CLAMP`` operator."""

    target = "aten.hardtanh.default"

    # Defaults taken from the ATen schema
    #   hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1)
    # used when a node does not carry the trailing scalar arguments.
    _DEFAULT_MIN_VAL = -1.0
    _DEFAULT_MAX_VAL = 1.0

    def __init__(self, *args):
        super().__init__(*args)

    @staticmethod
    def _get_bound(
        node: torch.fx.Node,
        inputs: List[TosaArg],
        index: int,
        name: str,
        default: float,
    ):
        """Return the hardtanh bound at positional ``index``.

        When the node did not pass the bound positionally (``inputs`` too
        short), fall back to ``node.kwargs[name]`` and then to ``default``.
        NOTE(review): assumes ``inputs`` mirrors ``node.args`` positionally;
        confirm against the caller that builds it.
        """
        if len(inputs) > index:
            return inputs[index].number
        return node.kwargs.get(name, default)

    def define_node(
        self,
        node: torch.fx.Node,
        tosa_graph: ts.TosaSerializer,
        inputs: List[TosaArg],
        output: TosaArg,
        is_quant_node: bool,
    ) -> None:
        """Append a TOSA CLAMP operator implementing hardtanh to ``tosa_graph``.

        Args:
            node: The FX node being lowered.
            tosa_graph: Serializer the CLAMP operator is added to.
            inputs: TOSA arguments of the node. ``inputs[0]`` is the tensor
                being clamped; ``inputs[1]`` / ``inputs[2]``, when present,
                hold ``min_val`` / ``max_val``.
            output: TOSA argument naming the result tensor.
            is_quant_node: If True, the bounds are quantized and written to
                the integer fields of the CLAMP attribute; otherwise they are
                written to the floating-point fields.
        """
        attr = ts.TosaSerializerAttribute()

        min_val = self._get_bound(node, inputs, 1, "min_val", self._DEFAULT_MIN_VAL)
        max_val = self._get_bound(node, inputs, 2, "max_val", self._DEFAULT_MAX_VAL)

        if is_quant_node:
            # Quantize the bounds with the quantization parameters of the
            # node feeding this one, so they are expressed in the same
            # quantized domain as the input values being clamped.
            qargs = get_quant_arg_upstream(node.all_input_nodes[0])
            clamp_min_qs = quantize_value(min_val, qargs)
            clamp_max_qs = quantize_value(max_val, qargs)
            # Floating-point bounds are unused for a quantized clamp.
            clamp_min_fp = 0.0
            clamp_max_fp = 0.0
        else:
            clamp_min_fp = min_val
            clamp_max_fp = max_val
            # Integer bounds are unused for a floating-point clamp.
            clamp_min_qs = 0
            clamp_max_qs = 0

        attr.ClampAttribute(
            tosa_graph.builder,
            clamp_min_qs,
            clamp_max_qs,
            clamp_min_fp,
            clamp_max_fp,
        )

        tosa_graph.addOperator(
            TosaOp.Op().CLAMP, [inputs[0].name], [output.name], attr
        )