# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast, Dict, List

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_INSERTED_PERMUTE

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpTranspose, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class TransposeVisitor(NodeVisitor):
    """Lower ``aten.permute_copy.default`` to a QNN ``Transpose`` op."""

    target = ["aten.permute_copy.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build the QNN ``Transpose`` op wrapper for a ``permute_copy`` node.

        Args:
            node: The ``aten.permute_copy.default`` node. ``node.args[0]`` is
                the input node and ``node.args[1]`` is the permutation order
                (dims may be negative, as PyTorch allows).
            nodes_to_wrappers: Mapping from FX nodes to tensor wrappers,
                passed through to ``define_tensor``.

        Returns:
            A ``PyQnnOpWrapper`` for ``Transpose`` with one input tensor, one
            output tensor, and a UINT_32 ``perm`` tensor parameter whose
            entries are all non-negative.
        """
        input_node = node.args[0]
        # NOTE(review): for permutes flagged with QCOM_INSERTED_PERMUTE the
        # input node (rather than this node) is handed to get_tensor;
        # presumably this selects which node's layout handling applies --
        # confirm against NodeVisitor.get_tensor.
        permute_node = input_node if QCOM_INSERTED_PERMUTE in node.meta else node
        input_tensor = self.get_tensor(input_node, permute_node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        # permutation; node.args[1] may be a tuple, so materialize a list.
        permute_order = list(cast(List[int], node.args[1]))

        # Permute with the raw order first: torch resolves negative dims and
        # rejects out-of-range / duplicate dims, so an invalid order still
        # fails here exactly as before.
        output_tensor = input_tensor.permute(permute_order)

        # QNN's `perm` is UINT_32. A negative dim such as -1 would wrap to a
        # huge index (NumPy < 2) or raise OverflowError (NumPy >= 2) when cast
        # with np.uint32, so normalize every dim into [0, rank). The permute
        # above guarantees rank == input_tensor.dim().
        rank = len(permute_order)
        permute_order = [dim % rank for dim in permute_order]
        permute_order_shape = [rank]

        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        transpose_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpTranspose.op_name,
        )

        # add input/output tensors
        transpose_op.AddInputTensors([input_tensor_wrapper])
        transpose_op.AddOutputTensors([output_tensor_wrapper])

        transpose_op.AddTensorParam(
            OpTranspose.param_perm,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(permute_order_shape),
            permute_order_shape,
            np.array(permute_order, dtype=np.uint32),
            # NOTE(review): meaning of this trailing flag is defined by
            # PyQnnWrapper.AddTensorParam -- confirm there.
            True,
        )
        return transpose_op