# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch
from executorch.backends.qualcomm.utils.constants import (
    QCOM_AXIS_ORDER,
    QCOM_QUANT_ATTRS,
    QCOM_QUANT_MAX,
    QCOM_QUANT_MIN,
    QCOM_SCALE,
    QCOM_ZERO_POINT,
)
from executorch.exir.dialects._ops import ops as exir_ops

from .node_visitor import get_parameter, NodeVisitor, register_node_visitor
from .qnn_constants import OpPRelu, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class PReLU(NodeVisitor):
    """Node visitor that builds an ``OpPRelu`` op wrapper.

    Handles both ``aten.leaky_relu.default`` and ``aten.prelu.default``; in
    either case the slope is materialized as a static tensor that is passed
    as the second input of the op.
    """

    target = ["aten.leaky_relu.default", "aten.prelu.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def _resolve_coeff(self, node, input_node, input_tensor):
        """Return ``(coeff, coeff_tensor)`` for ``node``.

        ``coeff`` is the Python scalar later used to derive the quantization
        scale; ``coeff_tensor`` is the tensor handed to ``define_tensor`` as
        the static slope input.
        """
        if node.target.__name__ == "aten.leaky_relu.default":
            # The slope is the optional second argument; fall back to 1e-2
            # when the node does not carry it.
            slope = node.args[1] if len(node.args) >= 2 else 1e-2
            slope_tensor = torch.full(input_tensor.shape, slope).to(torch.float32)
            return slope, slope_tensor

        weight_node = node.args[1]
        filled = torch.zeros(input_node.meta["val"].shape)
        weight = get_parameter(weight_node, self.edge_program)
        # While partitioning, parameters are FakeTensors (per the original
        # author's note); substitute ones as placeholder values so the op
        # can still be validated.
        if isinstance(weight, torch._subclasses.fake_tensor.FakeTensor):
            weight = torch.ones(weight.shape)

        if weight_node.meta["val"].shape[0] > 1:
            # Per-channel slopes: write weight[c] into every element whose
            # index along dim 1 is c.
            channel_count = input_node.meta["val"].shape[1]
            for channel in range(channel_count):
                channel_index = torch.tensor([channel])
                filled = filled.index_fill(1, channel_index, weight[channel])
            # Apply the recorded axis order, if any, to the slope tensor.
            if QCOM_AXIS_ORDER in input_node.meta:
                order = input_node.meta[QCOM_AXIS_ORDER]
                filled = filled.permute(dims=order).contiguous()
            # Simple min-max quantization: the largest slope is the scalar
            # used for the scale.
            return torch.max(weight).item(), filled

        # Single shared slope: fill a tensor shaped like the input with it.
        shared = weight.item()
        shared_tensor = torch.full(input_tensor.shape, shared).to(torch.float32)
        return shared, shared_tensor

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build the ``OpPRelu`` op wrapper for ``node``.

        Args:
            node: The leaky_relu / prelu node being lowered.
            nodes_to_wrappers: Mapping forwarded to every ``define_tensor``
                call made here.

        Returns:
            The op wrapper with two inputs (activation, slope tensor) and one
            output.
        """
        native_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE
        static_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC

        # Activation input.
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        activation_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            native_type,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        coeff, coeff_tensor = self._resolve_coeff(node, input_node, input_tensor)

        # Stand-alone fx node that carries the slope tensor's name and
        # metadata. Positional arguments are:
        # graph, name, op, target, args, kwargs.
        slope_node = torch.fx.Node(
            node.graph,
            node.name + "_runtime_scalar",
            "call_function",
            exir_ops.edge.aten.full.default,
            (),
            {},
        )

        node_quant_attrs = node.meta.get(QCOM_QUANT_ATTRS)
        if node_quant_attrs:
            # Start from a copy of the node's quant attrs, then override the
            # zero point and scale for the slope tensor.
            slope_quant_attrs = node_quant_attrs.copy()
            span = (
                slope_quant_attrs[QCOM_QUANT_MAX] - slope_quant_attrs[QCOM_QUANT_MIN]
            )
            # NOTE(review): the original author states coeff is guaranteed to
            # be positive, which is why a zero point of 0 is used here.
            slope_quant_attrs[QCOM_ZERO_POINT] = 0
            slope_quant_attrs[QCOM_SCALE] = coeff / span
            slope_node.meta[QCOM_QUANT_ATTRS] = slope_quant_attrs

        slope_wrapper = self.define_tensor(
            slope_node,
            coeff_tensor,
            static_type,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        # Output tensor.
        result_tensor = self.get_tensor(node, node)
        result_wrapper = self.define_tensor(
            node,
            result_tensor,
            native_type,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        op_wrapper = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpPRelu.op_name,
        )
        op_wrapper.AddInputTensors([activation_wrapper, slope_wrapper])
        op_wrapper.AddOutputTensors([result_wrapper])

        return op_wrapper