# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_DATA

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpGather, QNN_OP_PACKAGE_NAME_QTI_AISW
from .utils import get_parameter


@register_node_visitor
class Embedding(NodeVisitor):
    """Lower ``aten.embedding.default`` to a QNN Gather op.

    The weight table (``node.args[0]``) is registered as a STATIC tensor and
    the lookup indices (``node.args[1]``) as a NATIVE tensor. The Gather axis
    is fixed at 0.
    """

    target = ["aten.embedding.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build the QNN Gather op wrapper for one embedding ``node``.

        Args:
            node: the ``aten.embedding.default`` FX node. ``args[0]`` is the
                weight node and ``args[1]`` the indices node; any further
                arguments are not read here.
            nodes_to_wrappers: node-to-wrapper mapping, passed through to
                every ``define_tensor`` call.

        Returns:
            The ``PyQnnOpWrapper`` for the Gather op, with its inputs, output
            and axis parameter already attached.
        """
        tensor_kind = PyQnnWrapper.Qnn_TensorType_t

        # Weight table: fetched as a parameter of the edge program and
        # registered as a STATIC tensor.
        table_node = node.args[0]
        table_wrapper = self.define_tensor(
            table_node,
            get_parameter(table_node, self.edge_program),
            tensor_kind.QNN_TENSOR_TYPE_STATIC,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        # Lookup indices: registered as a NATIVE tensor.
        index_node = node.args[1]
        index_wrapper = self.define_tensor(
            index_node,
            self.get_tensor(index_node, node),
            tensor_kind.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        # The embedding result itself, also a NATIVE tensor.
        result_wrapper = self.define_tensor(
            node,
            self.get_tensor(node, node),
            tensor_kind.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpGather.op_name,
        )
        # Inputs follow the aten argument order: table first, then indices.
        op.AddInputTensors([table_wrapper, index_wrapper])
        op.AddOutputTensors([result_wrapper])

        # Axis is hard-coded to 0, i.e. indices select along the table's
        # first dimension.
        op.AddScalarParam(
            OpGather.param_axis,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_INT_32,
            {QCOM_DATA: np.int32(0)},
        )
        return op