# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch
from executorch.backends.qualcomm.utils.constants import QCOM_ENCODING, QCOM_QUANT_ATTRS

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpQuantize, QNN_OP_PACKAGE_NAME_QTI_AISW


class QuantizeOpBase(NodeVisitor):
    """Shared QNN lowering for the ``quantized_decomposed`` quantize ops.

    Subclasses only declare ``target``; this base class emits a QNN
    ``Quantize`` op whose single input is the node's first argument and whose
    single output is the node itself.
    """

    def __init__(self, *args) -> None:
        super().__init__(*args)

    @staticmethod
    def _get_schema_arg(
        node: torch.fx.Node, index: int, arg_schema: "torch._C.Argument"
    ) -> object:
        """Return the value bound to schema argument *arg_schema* on *node*.

        Lookup order mirrors how an FX call binds arguments: positional
        ``node.args[index]`` first, then ``node.kwargs[name]``, then the
        default declared in the op schema.

        Raises:
            IndexError: if the argument is absent from both ``args`` and
                ``kwargs`` and the schema declares no default.
        """
        if index < len(node.args):
            return node.args[index]
        name = arg_schema.name
        if name in node.kwargs:
            return node.kwargs[name]
        if arg_schema.has_default_value():
            return arg_schema.default_value
        raise IndexError(
            f"{node.target}: required argument '{name}' (position {index}) "
            f"is missing on node '{node.name}'"
        )

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Lower a quantize node to a QNN ``Quantize`` op.

        Side effect: overwrites ``node.meta[QCOM_QUANT_ATTRS]`` with a dict
        holding the encoding (the node's own target) plus every non-input
        schema argument (e.g. scale / zero point / dtype), keyed by its
        schema name.

        Args:
            node: The ``quantized_decomposed.quantize_per_*`` node to lower.
            nodes_to_wrappers: Mapping from FX nodes to their tensor
                wrappers, passed through to ``define_tensor``.

        Returns:
            The QNN op wrapper with its input and output tensors attached.

        Raises:
            IndexError: if a required schema argument cannot be found in the
                node's args, kwargs, or schema defaults.
        """
        # The float tensor being quantized is always the first argument.
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        inp_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )
        quant_input_tensors = [inp_tensor_wrapper]

        # Record the quantization parameters on the node before its output
        # tensor is defined below.
        # NOTE(review): presumably define_tensor reads QCOM_QUANT_ATTRS to
        # build the output encoding, which is why this must precede it.
        quant_attrs = {QCOM_ENCODING: node.target}
        arg_schemas = list(node.target._schema.arguments)[1:]
        for index, arg_schema in enumerate(arg_schemas, start=1):
            quant_attrs[arg_schema.name] = self._get_schema_arg(
                node, index, arg_schema
            )
        node.meta[QCOM_QUANT_ATTRS] = quant_attrs

        output_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )
        quant_output_tensors = [output_tensor_wrapper]

        quant_op = PyQnnWrapper.PyQnnOpWrapper(
            node.target.__name__,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpQuantize.op_name,
        )
        quant_op.AddInputTensors(quant_input_tensors)
        quant_op.AddOutputTensors(quant_output_tensors)

        return quant_op


@register_node_visitor
class PerTensorQuantize(QuantizeOpBase):
    """Visitor for ``quantized_decomposed.quantize_per_tensor.default``."""

    target = ["quantized_decomposed.quantize_per_tensor.default"]


@register_node_visitor
class PerChannelQuantize(QuantizeOpBase):
    """Visitor for ``quantized_decomposed.quantize_per_channel.default``."""

    target = ["quantized_decomposed.quantize_per_channel.default"]