# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import cast, Dict, Optional

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpStridedSlice, QNN_OP_PACKAGE_NAME_QTI_AISW


def _normalize_slice_bound(
    value: Optional[int], size: int, default: int, lower: int = 0
) -> int:
    """Resolve one bound of an ``aten.slice`` call to a concrete index.

    Mirrors the ``aten.slice`` convention:

    - ``None`` selects ``default``.
    - A negative value counts from the end of the dimension; ``size`` is added once.
    - The result is clamped into ``[lower, size]`` instead of being wrapped around.
      For example, with ``size=5``, ``-7`` resolves to ``0``, not ``3``.

    Args:
        value: The raw bound taken from the FX node arguments, or ``None``.
        size: Length of the dimension being sliced.
        default: Value returned when ``value`` is ``None``.
        lower: Smallest allowed result. Used to keep ``end >= start``.

    Returns:
        The normalized bound, guaranteed to lie in ``[lower, size]`` when
        ``value`` is not ``None``.
    """
    if value is None:
        return default
    value = cast(int, value)
    if value < 0:
        value += size
    return max(lower, min(value, size))


@register_node_visitor
class StrideSlice(NodeVisitor):
    """Lower ``aten.slice_copy.Tensor`` to a QNN ``StridedSlice`` op."""

    target = ["aten.slice_copy.Tensor"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build the QNN ``StridedSlice`` op wrapper for a slice node.

        Args:
            node: An ``aten.slice_copy.Tensor`` node. Its arguments are
                ``(input, dim=0, start=None, end=None, step=1)``; trailing
                arguments may be omitted.
            nodes_to_wrappers: Mapping of already-defined FX nodes to their
                QNN tensor wrappers, passed through to ``define_tensor``.

        Returns:
            The op wrapper. It has one input tensor and one output tensor,
            plus a ``ranges`` tensor parameter of shape ``[rank, 3]``. Each
            row of ``ranges`` holds ``(begin, end, step)`` for one input
            dimension.
        """
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        tensor_type = PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE

        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            tensor_type,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        output_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            tensor_type,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        # Optional schema arguments may be absent from node.args, so fall back
        # to the aten defaults instead of indexing past the end.
        num_args = len(node.args)
        input_tensor_rank = len(input_tensor.shape)

        dim = cast(int, node.args[1]) if num_args > 1 else 0
        if dim < 0:
            dim = dim % input_tensor_rank
        dim_size = input_tensor.shape[dim]

        # Normalize like aten.slice: shift negatives once, then clamp.
        # Plain modulo would wrap out-of-range negatives to wrong indices.
        start = _normalize_slice_bound(
            node.args[2] if num_args > 2 else None, dim_size, default=0
        )
        end = _normalize_slice_bound(
            node.args[3] if num_args > 3 else None,
            dim_size,
            default=dim_size,
            lower=start,
        )
        step = cast(int, node.args[4]) if num_args > 4 else 1

        # One (begin, end, step) triple per input dimension. Only `dim` is
        # actually sliced; every other dimension is kept whole.
        ranges = []
        for i, size in enumerate(input_tensor.shape):
            if i == dim:
                ranges.extend([start, end, step])
            else:
                ranges.extend([0, size, 1])

        range_shape = [input_tensor_rank, 3]

        stride_slice_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpStridedSlice.op_name,
        )
        stride_slice_op.AddInputTensors([input_tensor_wrapper])
        stride_slice_op.AddOutputTensors([output_tensor_wrapper])

        stride_slice_op.AddTensorParam(
            OpStridedSlice.param_ranges,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
            len(range_shape),
            range_shape,
            np.array(ranges, dtype=np.int32),
            True,
        )

        return stride_slice_op