1# Copyright (c) Qualcomm Innovation Center, Inc. 2# All rights reserved 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6import warnings 7from typing import cast, Dict 8 9import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper 10 11import numpy as np 12import torch 13from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA 14 15from .node_visitor import NodeVisitor, register_node_visitor 16from .qnn_constants import OpTopK, QNN_OP_PACKAGE_NAME_QTI_AISW 17 18 19@register_node_visitor 20class TopK(NodeVisitor): 21 target = ["aten.topk.default"] 22 23 def __init__(self, *args) -> None: 24 super().__init__(*args) 25 26 def define_node( 27 self, 28 node: torch.fx.Node, 29 nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], 30 ) -> PyQnnWrapper.PyQnnOpWrapper: 31 32 input_node = node.args[0] 33 input_tensor = self.get_tensor(input_node, node) 34 input_tensor_wrapper = self.define_tensor( 35 input_node, 36 input_tensor, 37 PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, 38 nodes_to_wrappers, 39 is_input_tensor=True, 40 ) 41 42 k = cast(int, node.args[1]) 43 44 if len(node.args) > 2: 45 dim = cast(int, node.args[2]) 46 if dim < 0: 47 dim = dim % len(input_tensor.shape) 48 if QCOM_AXIS_ORDER in node.meta: 49 dim = node.meta[QCOM_AXIS_ORDER].index(dim) 50 if dim != len(input_tensor.shape) - 1: 51 warnings.warn( 52 "[QNN Delegate Op Builder]: QNN currently only supports channel as dimension for topK.", 53 stacklevel=1, 54 ) 55 return 56 57 topk_input_tensors = [input_tensor_wrapper] 58 59 output_val_tensor = self.get_tensor(node, node, 0) 60 output_idx_tensor = self.get_tensor(node, node, 1).to(torch.int32) 61 62 # QNN constraint, topk output_0 requires having the same quant config as input 63 node.meta["quant_attrs"] = input_node.meta.get("quant_attrs") 64 output_val_tensor_wrapper = self.define_tensor( 65 node, 66 output_val_tensor, 67 PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, 68 nodes_to_wrappers, 69 is_input_tensor=False, 70 ) 71 72 # topk output_1 is index, do not quantize it. 73 node.meta.pop("quant_attrs", None) 74 output_index_tensor_wrapper = self.define_tensor( 75 node, 76 output_idx_tensor, 77 PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, 78 nodes_to_wrappers, 79 is_input_tensor=False, 80 wrapper_idx=1, 81 ) 82 topk_output_tensors = [output_val_tensor_wrapper, output_index_tensor_wrapper] 83 84 topk_op = PyQnnWrapper.PyQnnOpWrapper( 85 node.name, 86 QNN_OP_PACKAGE_NAME_QTI_AISW, 87 OpTopK.op_name, 88 ) 89 topk_op.AddInputTensors(topk_input_tensors) 90 topk_op.AddOutputTensors(topk_output_tensors) 91 92 topk_op.AddScalarParam( 93 OpTopK.param_k, 94 PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, 95 {"data": np.uint32(k)}, 96 ) 97 98 # As of QNN 2.26, QNN HTP backend only allows users to set this value to 1, or else it will fail at op validation 99 if len(node.args) > 3: 100 largest = cast(bool, node.args[3]) 101 topk_op.AddScalarParam( 102 OpTopK.param_largest, 103 PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8, 104 {QCOM_DATA: largest}, 105 ) 106 107 return topk_op 108