# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import cast, Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_DATA

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpReluMinMax, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class Clamp(NodeVisitor):
    """Lowers ``aten.clamp.default`` to the QNN ``ReluMinMax`` operator."""

    target = ["aten.clamp.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build a ``ReluMinMax`` op wrapper for a clamp node.

        Args:
            node: FX node with args ``(input, min?, max?)`` where the bound
                arguments may be absent or explicitly ``None``.
            nodes_to_wrappers: shared mapping used to deduplicate tensor
                wrappers across the graph.

        Returns:
            The configured ``PyQnnOpWrapper`` with min/max scalar params set.
        """
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        # An omitted bound clamps nothing on that side: default to the full
        # float32 range so ReluMinMax is a pass-through for that bound.
        output_min = torch.finfo(torch.float32).min
        output_max = torch.finfo(torch.float32).max

        # Either bound may be missing from node.args entirely (trailing
        # defaults dropped by the tracer) or present as an explicit None.
        # The original code guarded args[2] with a length check but indexed
        # args[1] unconditionally, which raises IndexError when clamp is
        # traced without an explicit min — guard both symmetrically.
        if len(node.args) > 1 and node.args[1] is not None:
            output_min = cast(float, node.args[1])
        if len(node.args) > 2 and node.args[2] is not None:
            output_max = cast(float, node.args[2])

        output_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        clamp_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpReluMinMax.op_name,
        )
        clamp_op.AddInputTensors([input_tensor_wrapper])
        clamp_op.AddOutputTensors([output_tensor_wrapper])
        # QNN expects the bounds as FLOAT_32 scalar params on the op.
        clamp_op.AddScalarParam(
            OpReluMinMax.param_max_value,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
            {QCOM_DATA: np.float32(output_max)},
        )
        clamp_op.AddScalarParam(
            OpReluMinMax.param_min_value,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
            {QCOM_DATA: np.float32(output_min)},
        )

        return clamp_op