# Contribution for More Operators
Thank you for contributing to the Qualcomm AI Engine Direct delegate for ExecuTorch. Reading and following these guidelines will help you quickly learn the essentials of implementing an operator builder, unblock yourself, and land pull requests more efficiently.

## Sections
* [References](#references)
* [Getting Started](#getting-started)
  * [Identify Unsupported Operator](#identify-unsupported-operator)
  * [Check Operator Spec](#check-operator-spec)
  * [Implementation](#implementation)
  * [Quantizer Annotation](#quantizer-annotation)
* [Issues](#issues)
* [Pull Requests](#pull-requests)

## References
### Qualcomm AI Engine Direct
- [Operator Definitions](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/MasterOpDef.html)
- [Supported Operators in Backends](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/operations.html#backend-supplements)

### PyTorch
- [torch.nn Operator Definitions](https://pytorch.org/docs/stable/nn.html)
- [torch.nn.functional Operator Definitions](https://pytorch.org/docs/stable/nn.functional.html)
- [ATen Operator Definitions](https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native)

## Getting Started
### Identify Unsupported Operator
Suppose we are enabling the following model:
```python
import torch


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_norm = torch.nn.LayerNorm([768], eps=1e-6)
        self.linear = torch.nn.Linear(768, 100)

    def forward(self, x):
        return self.linear(self.layer_norm(x))
```
When we try to lower it with the Qualcomm backend:
```python
import torch
from executorch.examples.qualcomm.utils import build_executorch_binary

build_executorch_binary(
    model=MyModel(),
    inputs=(torch.randn(200, 768),),
    soc_model="SM8650",
    file_name="my_model",
    dataset=None,
)
```
Assuming there is no `torch.nn.LayerNorm` support yet, you should see the following error log:
```bash
File "/executorch/backends/qualcomm/partition/qnn_partitioner.py", line 77, in is_node_supported
    op_wrapper = self.node_visitors[node.target.__name__].define_node(
KeyError: 'aten.native_layer_norm.default'
```
The log gets straight to the point: there is no suitable conversion for delegating this torch operator to Qualcomm AI Engine Direct. Here, `node_visitors` is a dictionary that maps an operator's target name to its implementation callback. The goal of this tutorial is to help you register the missing one.<br/>
The very first step is to identify which operator type we are going to support. Sometimes the target name of an operator is obscure. The following snippet helps you trace it back through its call stack:
```python
import torch
from executorch.backends.qualcomm.utils.utils import capture_program

prog = capture_program(MyModel(), (torch.randn(200, 768),))
for node in prog.exported_program.graph.nodes:
    if node.op == "call_function" and node.target.__name__ == "aten.native_layer_norm.default":
        print(node.meta["source_fn_stack"])
```
It gives a hint about the source PyTorch layer that the missing operator maps to:
```bash
[('l__self___layer_norm', <class 'torch.nn.modules.normalization.LayerNorm'>)]
```

### Check Operator Spec
- **Qualcomm AI Engine Direct**:<br/>
  You can collect information about `LayerNorm`'s IO from the documents mentioned in [Qualcomm AI Engine Direct Manual](#qualcomm-ai-engine-direct):
  * inputs
    - in[0] - input activation / required
    - in[1] - gamma / optional
    - in[2] - beta / optional
  * parameters
    - "epsilon" / optional
    - "axes" / required
  * outputs
    - out[0] - output activation / required

  The required tensors must be provided because the QNN runtime has no default values for them. The order of the IO tensors (`input activation`, `gamma`, `beta`) matters. Parameters (`epsilon`, `axes`) are different: they are recognized by their literal names.
  ```c
  typedef struct {
    /// A human-readable name for the operation instance.
    const char* name;
    /// The name of the operation package to which this operation's type belongs.
    const char* packageName;
    /// The name of operation type (e.g. Conv2D).
    const char* typeName;
    /// The number of static parameters provided in the params array.
    uint32_t numOfParams;
    /// Array of operation parameters.
    Qnn_Param_t* params;
    /// The number of input tensors.
    uint32_t numOfInputs;
    /// Array of input tensors.
    Qnn_Tensor_t* inputTensors;
    /// The number of output tensors.
    uint32_t numOfOutputs;
    /// Array of output tensors.
    Qnn_Tensor_t* outputTensors;
  } Qnn_OpConfigV1_t;
  ```
  This is the data structure the QNN SDK uses to check operator validity. During validation, tensors are retrieved sequentially and passed through a series of spec checks, while parameters are matched by name:
  ```c
  typedef struct {
    /// Parameter type: scalar or tensor
    Qnn_ParamType_t paramType;
    /// Name of the parameter
    const char* name;

    union UNNAMED {
      /// Scalar parameter specification
      Qnn_Scalar_t scalarParam;
      /// Tensor parameter specification; tensors referred to must be STATIC.
      Qnn_Tensor_t tensorParam;
    };
  } Qnn_Param_t;
  ```
  The `name` value must equal the parameter name described in [Operator Definitions](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/MasterOpDef.html). For `LayerNorm`, these are `epsilon` and `axes`.<br/>

  If you find it hard to correlate the missing operator with the documentation, this [table](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/SupportedOps.html) might help your search. In some cases an exact match may not exist. In that case, consider a mathematically equivalent approach or notify the maintainers for further analysis.

- **PyTorch**:<br/>
  We can also read the IO spec from the [function declaration](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/layer_norm.cpp) mentioned in [PyTorch Documentation](#pytorch):
  * inputs
    - in[0] - input activation / required
    - in[1] - normalized_shape / required
    - in[2] - weight_opt / optional
    - in[3] - bias_opt / optional
    - in[4] - eps / required

  By comparing against the [equation](https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html), we can map the arguments in the Qualcomm manual (`gamma` / `beta` / `epsilon`) to their PyTorch counterparts (`weight_opt` / `bias_opt` / `eps`). The unmatched parameter `axes` is discussed further in the [Implementation](#implementation) section.

### Implementation
Let's start by adding a new definition for the `LayerNorm` operator in `qnn_constants.py`:
```python
from dataclasses import dataclass


@dataclass(init=False, frozen=True)
class OpHardSwish:
    ...


# please insert it in alphabetical order
@dataclass(init=False, frozen=True)
class OpLayerNorm:
    op_name: str = "LayerNorm"
    param_epsilon = "epsilon"
    param_axes = "axes"


@dataclass(init=False, frozen=True)
class OpLogSoftmax:
    ...
```
The conventions are:
- op_name: string describing the operator
- param_xxx: string for each consumed parameter

These values must exactly match the literal values mentioned in [Qualcomm AI Engine Direct Manual](#qualcomm-ai-engine-direct) or in `QnnOpDef.h` under `$QNN_SDK_ROOT/include/QNN/`:
```c
#define QNN_OP_LAYER_NORM "LayerNorm"
#define QNN_OP_LAYER_NORM_PARAM_EPSILON "epsilon"
#define QNN_OP_LAYER_NORM_PARAM_AXES "axes"
```

Next, create a new file named in snake case (e.g. `op_layer_norm.py`) and import the required modules. The comments explain what each one is used for:
```python
from typing import Dict

# pybind interface for invoking QNN APIs
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
# tensors and other numerics are shipped in numpy format
import numpy as np
import torch
# common keywords of the Qualcomm backend
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
# the op builder inherits NodeVisitor and provides its own implementation;
# register_node_visitor keeps the dictionary of target name vs. callback
from .node_visitor import NodeVisitor, register_node_visitor
# the definitions required to build the operator in QNN
from .qnn_constants import OpLayerNorm, QNN_OP_PACKAGE_NAME_QTI_AISW
# utility to get parameter values when creating tensors in QNN
from .utils import get_parameter
```
Start with the class declaration:
```python
@register_node_visitor
class LayerNormVisitor(NodeVisitor):
    target = ["aten.native_layer_norm.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
        # the body is filled in step by step below
        ...
```
The `target` member is mandatory and must be a list, because multiple targets can map to the same implementation. For example, `aten.leaky_relu.default` and `aten.prelu.default` have similar equations and differ only in the negative slope.<br/>
`nodes_to_wrappers` is a dictionary that maintains the relationship between a graph node and its output tensor. It acts as a memo so that tensor objects are not created again for nodes that have already been traversed.<br/>

Now we can fill in the function body step by step:
1. Define the input activation tensor:
   ```python
   input_node = node.args[0]
   input_tensor = self.get_tensor(input_node, node)
   input_tensor_wrapper = self.define_tensor(
       input_node,
       input_tensor,
       PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
       nodes_to_wrappers,
       is_input_tensor=True,
   )
   ```
   Using the information from the [Check Operator Spec](#check-operator-spec) section, we can easily extract the desired nodes.<br/>
   The `get_tensor` method retrieves the torch tensor in the correct axis order if the `layout_transform` pass has been applied.<br/>
   The `define_tensor` method generates the tensor object for the QNN API, which is then memoized in the aforementioned `nodes_to_wrappers`.<br/>
   Some of its arguments deserve more explanation:
   - **node**: current graph node
   - **tensor**: torch tensor emitted by the node
   - **tensor_type**: type compatible with the QNN SDK. Typically `QNN_TENSOR_TYPE_NATIVE` for intermediate outputs and `QNN_TENSOR_TYPE_STATIC` for constant parameters
   - **nodes_to_wrappers**: dictionary of graph nodes and their output tensors (note: the tensor here is not a torch tensor but a wrapped object for QNN)
   - **is_input_tensor**: flag telling whether the current tensor is an input activation or a parameter, which is important for fixed-point mixed precision to work properly
   - **node_name**: (optional) tensor name specified by the user
   - **wrapper_idx**: (optional) defaults to zero if the node output is not a tuple; otherwise it indexes into the output tensors. e.g. when slicing an input tensor into multiple outputs, `wrapper_idx` is necessary for getting the correct wrapped tensor object

2. Define the input gamma / beta tensors:
   ```python
   weight_node = node.args[2]
   weight_tensor = get_parameter(weight_node, self.edge_program)
   weight_tensor_wrapper = self.define_tensor(
       weight_node,
       weight_tensor,
       PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
       nodes_to_wrappers,
       is_input_tensor=False,
   )

   bias_node = node.args[3]
   bias_tensor = get_parameter(bias_node, self.edge_program)
   bias_tensor_wrapper = self.define_tensor(
       bias_node,
       bias_tensor,
       PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
       nodes_to_wrappers,
       is_input_tensor=False,
   )
   ```
   The logic is similar and straightforward. Set the `tensor_type` and `is_input_tensor` arguments carefully according to each tensor's properties.

3. Define parameters:
   ```python
   normalized_shapes = node.args[1]
   if len(normalized_shapes) != 1:
       print("QNN only supports normalized output with rank 1")
       return

   axes = [len(input_tensor.shape) - 1]
   axes_shape = [len(axes)]
   epsilon = node.args[4]
   ```
   Here you can see a constraint introduced by Qualcomm AI Engine Direct. Unlike PyTorch's LayerNorm operator, QNN can only normalize over a 1-D `normalized_shape`, so `axes` is set to the last dimension of the input. For unsupported cases we print a message to remind the user and return early. The partitioner treats this as a validation failure and falls back to CPU for this operator.<br/>
   When passing tensor-type parameters through the pybind interface, extra information such as the tensor shape must also be shipped in list form, e.g. `axes_shape = [len(axes)]`. More details are given in the following steps.

4. Define the output tensor:
   ```python
   output_tensor = self.get_tensor(node, node, 0)
   output_tensor_wrapper = self.define_tensor(
       node,
       output_tensor,
       PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
       nodes_to_wrappers,
       is_input_tensor=False,
   )
   ```
   The input / output activations might map to the graph IOs (a.k.a. user inputs / outputs), whose corresponding types are `QNN_TENSOR_TYPE_APP_WRITE` / `QNN_TENSOR_TYPE_APP_READ`. Even so, you should still use `QNN_TENSOR_TYPE_NATIVE` for all nodes' IOs and leave the detection logic to the `define_tensor` method.

5. Generate the operator object in the QNN graph:
   ```python
   layer_norm_op = PyQnnWrapper.PyQnnOpWrapper(
       node.name,
       QNN_OP_PACKAGE_NAME_QTI_AISW,
       OpLayerNorm.op_name,
   )
   ```

6. Pass the IO tensors to the operator object:
   ```python
   layer_norm_op.AddInputTensors(
       [input_tensor_wrapper, weight_tensor_wrapper, bias_tensor_wrapper]
   )
   layer_norm_op.AddOutputTensors([output_tensor_wrapper])
   ```
   The IO tensor objects created earlier are gathered and shipped to the operator object, in the order required by the QNN spec.

7. Pass the parameters to the operator object:
   ```python
   layer_norm_op.AddScalarParam(
       OpLayerNorm.param_epsilon,
       PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
       {QCOM_DATA: np.float32(epsilon)},
   )
   layer_norm_op.AddTensorParam(
       OpLayerNorm.param_axes,
       PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
       len(axes_shape),
       axes_shape,
       np.array(axes, dtype=np.uint32),
       True,
   )
   ```
   The `Shape` property of each parameter in [Qualcomm AI Engine Direct Manual](#qualcomm-ai-engine-direct) tells you which API to use. For example:
   - "epsilon" > __Shape__: scalar
   - "axes" > __Shape__: 1D of shape[M]

   The arguments of `AddScalarParam` are:
   - **name**: string matching the parameter name in the Qualcomm AI Engine Direct manual
   - **data_type**: type compatible with the QNN SDK, e.g. `QNN_DATATYPE_FLOAT_32`, `QNN_DATATYPE_UINT_32`, etc.
   - **attr**: dictionary for shipping data; currently only the `QCOM_DATA` key is used

   The arguments of `AddTensorParam` are:
   - **name**: string matching the parameter name in the Qualcomm AI Engine Direct manual
   - **data_type**: type compatible with the QNN SDK, e.g. `QNN_DATATYPE_FLOAT_32`, `QNN_DATATYPE_UINT_32`, etc.
   - **rank**: number of dimensions of the tensor
   - **dims**: shape of the tensor
   - **data**: tensor data
   - **copy_data**: should be set to `True` for constant parameters

8. Finally, return the operator object so that the partitioner can validate it:
   ```python
   return layer_norm_op
   ```
   Also update `__init__.py` so that the new module is imported and `register_node_visitor` takes effect:
   ```python
   from . import (
       ...
       op_index_put,
       # please insert entries in alphabetical order
       op_layer_norm,
       op_linear,
       ...
   )

   __all__ = [
       ...
       op_index_put,
       # please insert entries in alphabetical order
       op_layer_norm,
       op_linear,
       ...
   ]
   ```

### Quantizer Annotation
The operator should now work on Qualcomm backends. To run it in fixed precision, `QnnQuantizer` must also insert observers correctly so that it can record the calibrated encodings. Please read more in the [Quantization Annotation Tutorial](../quantizer/README.md).
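
The Implementation section can be recapped with a plain-Python mock that needs neither ExecuTorch nor the QNN SDK. This is a simplified sketch under stated assumptions, not the backend's code. `node_visitors`, `register_node_visitor`, `lookup_visitor`, and `layer_norm_params` are hypothetical stand-ins. They only imitate the behavior described above: a dictionary keyed by the target name and filled in by a class decorator. Here `define_node` takes plain shapes instead of a `torch.fx.Node` and returns a dictionary instead of a `PyQnnOpWrapper`. The mock shows two things: why the lookup raises a `KeyError` before registration, and how the `normalized_shape` check, `axes`, and `axes_shape` from step 3 are derived.

```python
from typing import Dict, List, Optional, Sequence, Type

import numpy as np


class NodeVisitor:
    # Hypothetical base class: every visitor lists the ATen targets it lowers.
    target: List[str] = []


# Hypothetical registry mapping an ATen target name to its visitor class.
node_visitors: Dict[str, Type[NodeVisitor]] = {}


def register_node_visitor(cls: Type[NodeVisitor]) -> Type[NodeVisitor]:
    # Every name in `target` maps to the same class, which is why `target` is a list.
    for name in cls.target:
        node_visitors[name] = cls
    return cls


def lookup_visitor(target_name: str) -> NodeVisitor:
    # Like the partitioner, an unregistered target name raises KeyError.
    return node_visitors[target_name]()


def layer_norm_params(
    input_shape: Sequence[int], normalized_shape: Sequence[int], eps: float
) -> Optional[dict]:
    # Simplified step 3: None signals a validation failure (CPU fallback).
    if len(normalized_shape) != 1:
        return None
    axes = [len(input_shape) - 1]  # normalize over the last dimension only
    return {
        "epsilon": np.float32(eps),
        "axes": np.array(axes, dtype=np.uint32),
        "axes_shape": [len(axes)],
    }


# Before registration, the lookup fails with the same kind of KeyError as the log
# in "Identify Unsupported Operator".
try:
    lookup_visitor("aten.native_layer_norm.default")
except KeyError as err:
    print("missing visitor:", err)


@register_node_visitor
class LayerNormVisitor(NodeVisitor):
    target = ["aten.native_layer_norm.default"]

    def define_node(self, input_shape, normalized_shape, eps):
        return layer_norm_params(input_shape, normalized_shape, eps)


visitor = lookup_visitor("aten.native_layer_norm.default")
print(visitor.define_node((200, 768), [768], 1e-6))  # axes == [1], axes_shape == [1]
print(visitor.define_node((2, 4, 8), [4, 8], 1e-5))  # None: falls back to CPU
```

Returning `None` from `define_node`, as in the last call, corresponds to the validation failure that makes the partitioner keep the operator on CPU.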

## Issues
Please refer to the [issue section](../README.md#issues) for more information.

## Pull Requests
Please refer to the [PR section](../README.md#pull-requests) for more information.