# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast, Dict

import torch
from executorch.backends.xnnpack.operators.node_visitor import (
    NodeVisitor,
    register_node_visitor,
)
from executorch.backends.xnnpack.serialization.xnnpack_graph_schema import (
    OutputMinMax,
    XNNClamp,
    XNNGraph,
    XNode,
)
from executorch.backends.xnnpack.utils.utils import get_input_node


@register_node_visitor
class HardTanhVisitor(NodeVisitor):
    """Lower ``aten.hardtanh.default`` to an XNNPACK clamp node.

    The hardtanh bounds (``min_val``/``max_val``) are carried on the
    serialized ``XNode`` as an ``OutputMinMax``; the clamp itself is
    emitted with ``flags=0``.
    """

    target = "aten.hardtanh.default"

    # Defaults used when a bound is not present on the node.
    _DEFAULT_MIN_VAL = -1.0
    _DEFAULT_MAX_VAL = 1.0

    def __init__(self, *args) -> None:
        super().__init__(*args)

    @staticmethod
    def _get_scalar_arg(
        node: torch.fx.Node, index: int, name: str, default: float
    ) -> float:
        """Return the scalar at ``node.args[index]`` (or kwarg ``name``) as a float.

        Each bound is checked independently: exported graphs may drop
        trailing default arguments, so ``min_val`` can be present while
        ``max_val`` is not. The value is converted with ``float()``
        because ``typing.cast`` performs no runtime conversion, and an
        integer bound (e.g. ``hardtanh(x, 0, 6)``) would otherwise be
        forwarded as an ``int``.
        """
        if len(node.args) > index:
            value = node.args[index]
        else:
            value = node.kwargs.get(name, default)
        return float(cast(float, value))

    def define_node(
        self,
        node: torch.fx.Node,
        xnn_graph: XNNGraph,
        vals_to_ids: Dict[torch.fx.Node, int],
        debug_handle: int,
    ) -> None:
        """Serialize ``node`` as an ``XNNClamp`` and append it to ``xnn_graph``.

        Args:
            node: the ``aten.hardtanh.default`` FX node; ``args[0]`` is the
                input, optional ``args[1]``/``args[2]`` (or kwargs
                ``min_val``/``max_val``) are the clamp bounds.
            xnn_graph: the graph being built; the new node is appended to
                ``xnn_graph.xnodes``.
            vals_to_ids: mapping from FX nodes to serialized value ids.
            debug_handle: debug handle recorded on the serialized node.
        """
        self.define_nodes_tensor_inputs_outputs(node, xnn_graph, vals_to_ids)

        output_min = self._get_scalar_arg(node, 1, "min_val", self._DEFAULT_MIN_VAL)
        output_max = self._get_scalar_arg(node, 2, "max_val", self._DEFAULT_MAX_VAL)

        input_id = vals_to_ids[get_input_node(node, 0)]
        output_id = vals_to_ids[node]

        ser_node = XNode(
            xnode_union=XNNClamp(
                input_id=input_id,
                output_id=output_id,
                flags=0,
            ),
            debug_handle=debug_handle,
            output_min_max=OutputMinMax(output_min=output_min, output_max=output_max),
        )
        xnn_graph.xnodes.append(ser_node)