# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import torch
from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpCast, OpConvert, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class To(NodeVisitor):
    """Lower ``aten._to_copy.default`` to a QNN ``Cast`` or ``Convert`` op.

    ``Cast`` is emitted when ``is_cast_node`` returns True; otherwise
    ``Convert`` is emitted. Both ops take the node's single input tensor
    and produce its single output tensor.
    """

    target = ["aten._to_copy.default"]
    # Required |input offset - output offset| for a signed/unsigned pair of
    # the same width to be castable: 2**7 for 8-bit and 2**15 for 16-bit.
    sufixed_8_offset_diff = 128
    sufixed_16_offset_diff = 32768
    # Absolute tolerance used when comparing input and output scales.
    epsilon = 1e-6
    # Signed/unsigned pair of 8-bit fixed-point QNN data types.
    sufixed_8 = {
        PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_8,
        PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_8,
    }
    # Signed/unsigned pair of 16-bit fixed-point QNN data types.
    sufixed_16 = {
        PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_SFIXED_POINT_16,
        PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UFIXED_POINT_16,
    }

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def _is_qparam_castable(
        self, in_offset, out_offset, in_scale, out_scale, offset_diff
    ) -> bool:
        """Return True if two sets of quantization params differ only by a
        fixed offset shift.

        The scales must agree to within ``self.epsilon`` (strict ``<``), and
        the offsets must differ by exactly ``offset_diff`` in absolute value.
        """
        return (
            abs(in_scale - out_scale) < self.epsilon
            and abs(in_offset - out_offset) == offset_diff
        )

    def is_cast_node(self, node: torch.fx.Node) -> bool:
        """Decide whether ``node`` can be lowered to ``Cast`` (vs ``Convert``).

        Returns:
            True if either the input node or ``node`` itself lacks
            ``QCOM_QUANT_ATTRS`` in its ``meta``.

            Otherwise, True only when the input/output data types are exactly
            the signed/unsigned pair of the same width (8- or 16-bit), with
            matching scales and offsets that differ by that width's offset
            gap.

            False for every other quantized combination.
        """
        input_node = node.args[0]

        # Unless both sides are quantized there is no quant-to-quant
        # conversion to consider, so Convert is never needed.
        if not (
            input_node.meta.get(QCOM_QUANT_ATTRS) and node.meta.get(QCOM_QUANT_ATTRS)
        ):
            return True

        input_tensor = self.get_tensor(input_node, node)
        _, inp_qconfs = self.get_quant_encoding_conf(input_node, False)
        inp_dtype = self.get_data_type(input_tensor, inp_qconfs)

        output_tensor = self.get_tensor(node, node)
        _, out_qconfs = self.get_quant_encoding_conf(node, False)
        out_dtype = self.get_data_type(output_tensor, out_qconfs)

        # Presumably a signed/unsigned pair with equal scale whose offsets are
        # 2**(bits-1) apart encodes the same real values, so a plain Cast is
        # enough; anything else needs a requantizing Convert.
        # NOTE(review): confirm against the QNN Cast/Convert op definitions.
        dtype_pair = {inp_dtype, out_dtype}
        for dtype_group, offset_diff in (
            (self.sufixed_8, self.sufixed_8_offset_diff),
            (self.sufixed_16, self.sufixed_16_offset_diff),
        ):
            if dtype_pair == dtype_group:
                return self._is_qparam_castable(
                    inp_qconfs["offset"],
                    out_qconfs["offset"],
                    inp_qconfs["scale"],
                    out_qconfs["scale"],
                    offset_diff,
                )
        return False

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build the QNN op wrapper for an ``aten._to_copy.default`` node.

        Defines the input (``node.args[0]``) and output tensors as
        ``QNN_TENSOR_TYPE_NATIVE`` via ``define_tensor`` (which receives
        ``nodes_to_wrappers``). It then returns a ``PyQnnOpWrapper`` named
        after ``node.name`` in the ``QNN_OP_PACKAGE_NAME_QTI_AISW`` package,
        using ``OpCast`` or ``OpConvert`` as chosen by ``is_cast_node``.
        """
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        output_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        qnn_op = OpCast if self.is_cast_node(node) else OpConvert
        op = PyQnnWrapper.PyQnnOpWrapper(
            node.name, QNN_OP_PACKAGE_NAME_QTI_AISW, qnn_op.op_name
        )
        op.AddInputTensors([input_tensor_wrapper])
        op.AddOutputTensors([output_tensor_wrapper])

        return op