# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import math
from typing import cast, Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_DATA

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpStridedSlice, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class SelectCopy(NodeVisitor):
    """Lower ``aten.select_copy.int`` / ``aten.select.int`` to QNN StridedSlice.

    ``select(input, dim, index)`` is expressed as a StridedSlice that keeps the
    full extent of every axis except ``dim``. Along ``dim`` it takes the single
    position ``index``, and it sets bit ``dim`` of the ``shrink_axes`` bitmask
    for that axis.
    """

    target = ["aten.select_copy.int", "aten.select.int"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build the StridedSlice op wrapper for a select node.

        Args:
            node: The select node; ``node.args`` is read as
                ``(input, dim, index)``.
            nodes_to_wrappers: Mapping of FX nodes to tensor wrappers,
                forwarded to ``define_tensor``.

        Returns:
            The configured StridedSlice ``PyQnnOpWrapper``, with the input
            tensor, the output tensor, a ``ranges`` tensor param and a
            ``shrink_axes`` scalar param attached.
        """
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        output_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        input_tensor_rank = len(input_tensor.shape)

        # Normalize negative dim / index to non-negative positions.
        dim = cast(int, node.args[1])
        if dim < 0:
            dim = dim % input_tensor_rank
        index = cast(int, node.args[2]) % input_tensor.shape[dim]

        # One (begin, end, stride) triple per input axis. The untouched axes
        # use end == size, i.e. ranges are half-open, so selecting the single
        # element at `index` needs end == index + 1 (end == index would
        # describe an empty slice).
        ranges = []
        for axis, size in enumerate(input_tensor.shape):
            if axis == dim:
                ranges.extend([index, index + 1, 1])
            else:
                ranges.extend([0, size, 1])

        range_shape = [input_tensor_rank, 3]

        stride_slice_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpStridedSlice.op_name,
        )
        stride_slice_op.AddInputTensors([input_tensor_wrapper])
        stride_slice_op.AddOutputTensors([output_tensor_wrapper])

        stride_slice_op.AddTensorParam(
            OpStridedSlice.param_ranges,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
            len(range_shape),
            range_shape,
            np.array(ranges, dtype=np.int32),
            True,
        )

        # Bitmask with only bit `dim` set.
        stride_slice_op.AddScalarParam(
            OpStridedSlice.param_shrink_axes,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            {QCOM_DATA: np.uint32(math.pow(2, dim))},
        )

        return stride_slice_op