# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import warnings
from typing import cast, Dict, List

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpConcat, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class Cat(NodeVisitor):
    """Node visitor that lowers ``aten.cat.default`` to a QNN ``Concat`` op."""

    target = ["aten.cat.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    @staticmethod
    def _get_axis(node: torch.fx.Node) -> int:
        """Return the non-negative concat axis for ``node``, remapped by any recorded axis order.

        The axis is taken from the second positional argument when present,
        otherwise from a ``dim`` keyword argument, defaulting to 0.
        """
        # node.args[1] might not exist; dim may instead arrive as a keyword.
        if len(node.args) > 1:
            axis = cast(int, node.args[1])
        else:
            axis = cast(int, node.kwargs.get("dim", 0))

        # Normalize a negative axis against the rank of the output value.
        if axis < 0:
            axis += node.meta["val"].dim()

        # When an axis order is recorded on the node (presumably a layout
        # permutation applied by an earlier pass -- confirm), translate the
        # axis to its position within that order.
        if QCOM_AXIS_ORDER in node.meta:
            axis = node.meta[QCOM_AXIS_ORDER].index(axis)

        return axis

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build a QNN Concat op wrapper for ``node``.

        Args:
            node: The ``aten.cat.default`` FX node; ``node.args[0]`` is the
                list of input nodes to concatenate.
            nodes_to_wrappers: Mapping from FX nodes to tensor wrappers,
                passed through to ``define_tensor``.

        Returns:
            The configured ``PyQnnOpWrapper``, or ``None`` if the defensive
            input-count check below fails (a warning is emitted in that case).
        """
        list_of_tensors = cast(List[torch.fx.Node], node.args[0])
        list_of_tensor_wrappers = []

        for tensor_input in list_of_tensors:
            input_tensor = self.get_tensor(tensor_input, node)
            list_of_tensor_wrappers.append(
                self.define_tensor(
                    tensor_input,
                    input_tensor,
                    PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
                    nodes_to_wrappers,
                    is_input_tensor=True,
                )
            )

        # Defensive sanity check: one wrapper is appended per input above, so
        # this should not trigger unless that loop is changed.
        if len(list_of_tensors) != len(list_of_tensor_wrappers):
            warnings.warn(
                "[QNN Delegate Op Builder]: The number of input tensors is not equal to the number of input tensor wrappers.",
                stacklevel=1,
            )
            return

        output_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        axis = self._get_axis(node)

        concat_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpConcat.op_name,
        )
        concat_op.AddInputTensors(list_of_tensor_wrappers)
        concat_op.AddOutputTensors([output_tensor_wrapper])

        # The axis is passed to QNN as an unsigned 32-bit scalar parameter.
        concat_op.AddScalarParam(
            OpConcat.param_axis,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {QCOM_DATA: np.uint32(axis)},
        )

        return concat_op