# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import cast, Dict, List

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_AXIS_ORDER, QCOM_DATA

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpReduceMean, QNN_OP_PACKAGE_NAME_QTI_AISW


@register_node_visitor
class MeanDim(NodeVisitor):
    """Node visitor that builds an ``OpReduceMean`` op for ``aten.mean.dim``."""

    target = ["aten.mean.dim"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build the op wrapper for a single ``aten.mean.dim`` node.

        Args:
            node: The FX node to translate. ``node.args[0]`` is the input
                node, ``node.args[1]`` is the ``dim`` argument and the
                optional ``node.args[2]`` is the ``keepdim`` flag.
            nodes_to_wrappers: Mapping of already defined tensor wrappers,
                forwarded to ``define_tensor``.

        Returns:
            The op wrapper with its input tensor, output tensor, axes
            parameter and (when given) keep-dims parameter attached.
        """
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        # The schema declares ``dim`` as ``int[1]?``, so besides a list of ints
        # it may be a bare int, None, or an empty list. The latter two mean
        # "reduce over every dimension" in PyTorch.
        rank = len(input_node.meta["val"].shape)
        dim_arg = node.args[1]
        if isinstance(dim_arg, int):
            mean_dims = [dim_arg]
        elif dim_arg is None or len(dim_arg) == 0:
            mean_dims = list(range(rank))
        else:
            mean_dims = cast(List[int], dim_arg)

        # Fold negative dims into the [0, rank) range.
        mean_dims = [mean_dim % rank for mean_dim in mean_dims]

        # Each dim is replaced by its position inside the recorded axis order.
        # NOTE(review): presumably this accounts for a layout permutation
        # applied by an earlier pass - confirm against the pass that sets it.
        if QCOM_AXIS_ORDER in node.meta:
            axis_order = node.meta[QCOM_AXIS_ORDER]
            mean_dims = [axis_order.index(mean_dim) for mean_dim in mean_dims]
        mean_dims_shape = [len(mean_dims)]

        output_tensor = self.get_tensor(node, node)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        reduce_mean_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpReduceMean.op_name,
        )
        reduce_mean_op.AddInputTensors([input_tensor_wrapper])
        reduce_mean_op.AddOutputTensors([output_tensor_wrapper])
        reduce_mean_op.AddTensorParam(
            OpReduceMean.param_axes,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(mean_dims_shape),
            mean_dims_shape,
            np.array(mean_dims, dtype=np.uint32),
            True,
        )

        # The keep-dims parameter is only attached when the node carries it
        # explicitly; otherwise the parameter is left unset.
        # NOTE(review): presumably the backend default then matches
        # ``keepdim=False`` - confirm.
        if len(node.args) > 2:
            keep_dims = cast(bool, node.args[2])
            reduce_mean_op.AddScalarParam(
                OpReduceMean.param_keep_dims,
                PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_BOOL_8,
                {QCOM_DATA: keep_dims},
            )

        return reduce_mean_op