# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
from executorch.exir.dialects._ops import ops as exir_ops

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpElementWisePower, QNN_OP_PACKAGE_NAME_QTI_AISW


def _scalar_quant_attrs(pow_quant_attrs: Dict, scalar: float) -> Dict:
    """Derive quantization attributes for the constant exponent tensor.

    Starts from a shallow copy of the pow node's quantization attributes and
    overrides ``scale`` and ``zero_point``. The exponent then lands exactly on
    one end of the integer range under ``q = round(x / scale) + zero_point``:
    ``quant_max`` for a positive scalar and ``quant_min`` for a negative one.

    Args:
        pow_quant_attrs: Quantization attributes of the pow node; must
            contain ``quant_min`` and ``quant_max``. It is not modified.
        scalar: The exponent of ``aten.pow.Tensor_Scalar``.

    Returns:
        A new dict equal to ``pow_quant_attrs`` except for ``scale`` and
        ``zero_point``.
    """
    quant_attrs = pow_quant_attrs.copy()
    quant_min = quant_attrs["quant_min"]
    quant_max = quant_attrs["quant_max"]
    quant_range = quant_max - quant_min
    # Anchor at quant_min rather than a literal 0 so the mapping is also
    # correct for signed ranges (e.g. [-128, 127]). For unsigned ranges
    # quant_min == 0, so this matches the previous behavior.
    quant_attrs["zero_point"] = quant_min if scalar >= 0 else quant_max
    # A zero exponent would otherwise give scale == 0, which is not a valid
    # quantization scale. Any positive scale represents 0 exactly here,
    # because 0 quantizes to zero_point.
    quant_attrs["scale"] = (abs(scalar) or 1.0) / quant_range
    return quant_attrs


# TODO: add more classes like PowTensorTensor if needed.
@register_node_visitor
class PowTensorScalar(NodeVisitor):
    """Lower ``aten.pow.Tensor_Scalar`` to a QNN ``ElementWisePower`` op.

    The scalar exponent is materialized as a static float32 tensor and fed
    to the op as its second input.
    """

    target = ["aten.pow.Tensor_Scalar"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build the QNN op wrapper for ``node``.

        Args:
            node: The ``aten.pow.Tensor_Scalar`` node. ``node.args[0]`` is the
                base tensor node and ``node.args[1]`` is the scalar exponent.
            nodes_to_wrappers: Mapping from FX nodes to tensor wrappers,
                passed through to ``define_tensor``.

        Returns:
            A ``PyQnnOpWrapper`` for ``ElementWisePower`` with inputs
            ``[base, exponent]`` and the node's output tensor as its output.
        """
        out_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            out_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )
        pow_output_tensors = [output_tensor_wrapper]

        # Base tensor input.
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        # Scalar exponent, materialized as a float32 constant tensor.
        scalar = node.args[1]
        scalar_tensor = torch.tensor(scalar).to(torch.float32)

        # Synthetic FX node that carries the constant tensor (and its quant
        # attrs) into define_tensor. This code does not insert it into
        # node.graph. Positional args are: graph, name, op, target, args,
        # kwargs.
        scalar_node = torch.fx.Node(
            node.graph,
            node.name + "_runtime_scalar",
            "call_function",
            exir_ops.edge.aten.scalar_tensor.default,
            (),  # args
            {},  # kwargs
        )

        # When the pow node is quantized, the exponent constant needs its
        # own scale/zero_point so that it is exactly representable.
        if pow_quant_attrs := node.meta.get(QCOM_QUANT_ATTRS):
            scalar_node.meta[QCOM_QUANT_ATTRS] = _scalar_quant_attrs(
                pow_quant_attrs, scalar
            )

        scalar_tensor_wrapper = self.define_tensor(
            scalar_node,
            scalar_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        pow_input_tensors = [input_tensor_wrapper, scalar_tensor_wrapper]

        pow_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpElementWisePower.op_name,
        )
        pow_op.AddInputTensors(pow_input_tensors)
        pow_op.AddOutputTensors(pow_output_tensors)

        return pow_op