# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""QNN node visitors that lower ``quantized_decomposed`` dequantize ops."""
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpDequantize, QNN_OP_PACKAGE_NAME_QTI_AISW


class DequantizeOpBase(NodeVisitor):
    """Shared lowering of a dequantize FX node to a QNN ``Dequantize`` op.

    Subclasses only declare which ``target`` op names they handle; the
    lowering in :meth:`define_node` is identical for all of them.
    """

    def __init__(self, *args) -> None:
        # Forward positional construction arguments unchanged to NodeVisitor.
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build the QNN op wrapper for a single dequantize node.

        Args:
            node: The dequantize FX node; its first positional argument
                (``node.args[0]``) is the tensor being dequantized.
            nodes_to_wrappers: Mapping forwarded to ``define_tensor`` for both
                the input and output tensors (presumably used to reuse
                already-defined tensor wrappers — see ``NodeVisitor``).

        Returns:
            A ``PyQnnOpWrapper`` named after ``node.target.__name__`` in the
            ``QNN_OP_PACKAGE_NAME_QTI_AISW`` package with op type
            ``OpDequantize.op_name``, wired to one input and one output
            tensor, both defined as ``QNN_TENSOR_TYPE_NATIVE``.
        """
        native = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE

        # Input side: the tensor feeding the dequantize node.
        src_node = node.args[0]
        src_wrapper = self.define_tensor(
            src_node,
            self.get_tensor(src_node, node),
            native,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        # Output side: the dequantize node's own result tensor.
        dst_wrapper = self.define_tensor(
            node,
            self.get_tensor(node, node),
            native,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        op_wrapper = PyQnnWrapper.PyQnnOpWrapper(
            node.target.__name__,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpDequantize.op_name,
        )
        op_wrapper.AddInputTensors([src_wrapper])
        op_wrapper.AddOutputTensors([dst_wrapper])
        return op_wrapper


@register_node_visitor
class PerTensorDequantize(DequantizeOpBase):
    """Visitor for the per-tensor ``dequantize`` overloads."""

    target = [
        "quantized_decomposed.dequantize_per_tensor.default",
        "quantized_decomposed.dequantize_per_tensor.tensor",
    ]


@register_node_visitor
class PerChannelDequantize(DequantizeOpBase):
    """Visitor for the per-channel ``dequantize`` overloads."""

    target = [
        "quantized_decomposed.dequantize_per_channel.default",
        "quantized_decomposed.dequantize_per_channel.tensor",
    ]