# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import warnings
from typing import Dict

import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper

import numpy as np
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_DATA

from .node_visitor import NodeVisitor, register_node_visitor
from .qnn_constants import OpLayerNorm, QNN_OP_PACKAGE_NAME_QTI_AISW
from .utils import get_parameter


@register_node_visitor
class LayerNormVisitor(NodeVisitor):
    """Node visitor that builds a QNN LayerNorm op wrapper.

    Handles ``aten.native_layer_norm.default`` nodes. Only normalization over
    exactly the last input dimension is lowered; any other ``normalized_shape``
    is rejected with a warning.
    """

    target = ["aten.native_layer_norm.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        """Build the op wrapper for a ``native_layer_norm`` node.

        The node arguments are read positionally as
        ``(input, normalized_shape, weight, bias, eps)``.

        Args:
            node: The layer-norm node being lowered.
            nodes_to_wrappers: Mapping passed through to ``define_tensor`` for
                every tensor wrapper created here.

        Returns:
            The populated op wrapper, or ``None`` (after emitting a warning)
            when ``normalized_shape`` is anything other than the single last
            dimension of the input.
        """
        input_node = node.args[0]
        input_tensor = self.get_tensor(input_node, node)
        input_tensor_wrapper = self.define_tensor(
            input_node,
            input_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=True,
        )

        normalized_shapes = node.args[1]
        # Reject the node unless normalized_shape is exactly
        # [input.shape[-1]]. This must be `or`: with `and`, a multi-dimension
        # normalized_shape whose first entry happens to equal the last input
        # dimension (e.g. input (2, 4, 4) with normalized_shape [4, 4]) would
        # slip through and be lowered with only the last axis.
        if (
            len(normalized_shapes) != 1
            or normalized_shapes[0] != input_tensor.shape[-1]
        ):
            warnings.warn(
                "[QNN Delegate Op Builder]: Only supports normalization with last input dimension.",
                stacklevel=1,
            )
            return None

        # Single normalization axis: the last dimension of the input.
        axis = [len(input_tensor.shape) - 1]
        # Dimensions of the 1-D axes parameter tensor itself.
        axis_shape = [len(axis)]

        weight_node = node.args[2]
        weight_tensor = get_parameter(weight_node, self.edge_program)
        weight_tensor_wrapper = self.define_tensor(
            weight_node,
            weight_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        bias_node = node.args[3]
        bias_tensor = get_parameter(bias_node, self.edge_program)
        bias_tensor_wrapper = self.define_tensor(
            bias_node,
            bias_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        epsilon = node.args[4]

        # NOTE(review): index 0 presumably selects the normalized output among
        # the values produced by native_layer_norm -- confirm against
        # NodeVisitor.get_tensor.
        output_tensor = self.get_tensor(node, node, 0)
        output_tensor_wrapper = self.define_tensor(
            node,
            output_tensor,
            PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
            nodes_to_wrappers,
            is_input_tensor=False,
        )

        layer_norm_op = PyQnnWrapper.PyQnnOpWrapper(
            node.name,
            QNN_OP_PACKAGE_NAME_QTI_AISW,
            OpLayerNorm.op_name,
        )
        layer_norm_op.AddInputTensors(
            [input_tensor_wrapper, weight_tensor_wrapper, bias_tensor_wrapper]
        )
        layer_norm_op.AddOutputTensors([output_tensor_wrapper])
        layer_norm_op.AddScalarParam(
            OpLayerNorm.param_epsilon,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
            {QCOM_DATA: np.float32(epsilon)},
        )
        # Arguments: param name, dtype, rank of the axes tensor, its dims, and
        # its data. NOTE(review): the meaning of the trailing `True` flag is
        # not visible here -- confirm against PyQnnOpWrapper.AddTensorParam.
        layer_norm_op.AddTensorParam(
            OpLayerNorm.param_axes,
            PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
            len(axis_shape),
            axis_shape,
            np.array(axis, dtype=np.uint32),
            True,
        )

        return layer_norm_op