# Contribution for More Operators
Thank you for contributing to the Qualcomm AI Engine Direct delegate for ExecuTorch. Reading and following these guidelines will help you quickly grasp the essentials of implementing an operator builder, unblock yourself, and land pull requests more efficiently.

## Sections
* [References](#references)
* [Getting Started](#getting-started)
    * [Identify Unsupported Operator](#identify-unsupported-operator)
    * [Check Operator Spec](#check-operator-spec)
    * [Implementation](#implementation)
    * [Quantizer Annotation](#quantizer-annotation)
* [Issues](#issues)
* [Pull Requests](#pull-requests)

## References
### Qualcomm AI Engine Direct
- [Operator Definitions](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/MasterOpDef.html)
- [Supported Operators in Backends](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/operations.html#backend-supplements)

### PyTorch
- [torch.nn Operator Definitions](https://pytorch.org/docs/stable/nn.html)
- [torch.nn.functional Operator Definitions](https://pytorch.org/docs/stable/nn.functional.html)
- [ATen Operator Definitions](https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native)

## Getting Started
### Identify Unsupported Operator
Suppose we are enabling the following model:
```python
import torch


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layer_norm = torch.nn.LayerNorm([768], eps=1e-6)
        self.linear = torch.nn.Linear(768, 100)

    def forward(self, x):
        return self.linear(self.layer_norm(x))
```
When we try to lower it with the Qualcomm backend:
```python
from executorch.examples.qualcomm.utils import build_executorch_binary

build_executorch_binary(
    model=MyModel(),
    inputs=(torch.randn(200, 768),),
    soc_model="SM8650",
    file_name="my_model",
    dataset=None,
)
```
Assuming there is no `torch.nn.LayerNorm` support, you should see the following error log:
```bash
File "/executorch/backends/qualcomm/partition/qnn_partitioner.py", line 77, in is_node_supported
    op_wrapper = self.node_visitors[node.target.__name__].define_node(
KeyError: 'aten.native_layer_norm.default'
```
This log gets straight to the point: there is no suitable conversion for delegating the torch operator to Qualcomm AI Engine Direct. Here `node_visitors` is a dictionary that maps an operator target name to its implementation callback. This tutorial aims to help you register the missing one.<br/>
The very first step is to locate which operator we are going to support. Since the target name of an operator can sometimes be obscure, the following snippet helps trace it back through its call stack:
```python
from executorch.backends.qualcomm.utils.utils import capture_program

prog = capture_program(MyModel(), (torch.randn(200, 768),))
for node in prog.exported_program.graph.nodes:
    if node.op == "call_function" and node.target.__name__ == 'aten.native_layer_norm.default':
        print(node.meta["source_fn_stack"])
```
It provides a hint about the source PyTorch layer that the missing operator maps to:
```bash
[('l__self___layer_norm', <class 'torch.nn.modules.normalization.LayerNorm'>)]
```

### Check Operator Spec
- **Qualcomm AI Engine Direct**:<br/>
  You can collect the IO information of `LayerNorm` via the documents mentioned in the [Qualcomm AI Engine Direct Manual](#qualcomm-ai-engine-direct):
  * inputs
    - in[0] - input activation / required
    - in[1] - gamma / optional
    - in[2] - beta / optional
  * parameters
    - "epsilon" / optional
    - "axes" / required
  * outputs
    - out[0] - output activation / required

  The required tensors must be provided, since no default values are given inside the QNN runtime. The order of the IO tensors (`input activation`, `gamma`, `beta`) matters, whereas the parameters (`epsilon`, `axes`) are recognized by their literal names:
  ```c
  typedef struct {
      /// A human-readable name for the operation instance.
      const char* name;
      /// The name of the operation package to which this operation's type belongs.
      const char* packageName;
      /// The name of operation type (e.g. Conv2D).
      const char* typeName;
      /// The number of static parameters provided in the params array.
      uint32_t numOfParams;
      /// Array of operation parameters.
      Qnn_Param_t* params;
      /// The number of input tensors.
      uint32_t numOfInputs;
      /// Array of input tensors.
      Qnn_Tensor_t* inputTensors;
      /// The number of output tensors.
      uint32_t numOfOutputs;
      /// Array of output tensors.
      Qnn_Tensor_t* outputTensors;
  } Qnn_OpConfigV1_t;
  ```
  This is the data structure used to check operator validity in the QNN SDK. During the validation process, tensors are retrieved sequentially and passed through a series of spec examinations, while parameters are matched by their names:
  ```c
  typedef struct {
      /// Parameter type: scalar or tensor
      Qnn_ParamType_t paramType;
      /// Name of the parameter
      const char* name;

      union UNNAMED {
          /// Scalar parameter specification
          Qnn_Scalar_t scalarParam;
          /// Tensor parameter specification; tensors referred to must be STATIC.
          Qnn_Tensor_t tensorParam;
      };
  } Qnn_Param_t;
  ```
  The `name` value equals the parameter name described in [Operator Definitions](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/MasterOpDef.html); for the `LayerNorm` case, these are `epsilon` and `axes`.<br/>

  If you find it hard to correlate the missing operator with the documentation, this [table](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/SupportedOps.html) might be helpful for searching. In some cases, an exact match may not exist; consider a mathematically equivalent approach or notify a maintainer for further analysis.

- **PyTorch**:<br/>
  We can also read the IO spec from the [function declaration](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/layer_norm.cpp) mentioned in [PyTorch Documentation](#pytorch):
  * inputs
    - in[0] - input activation / required
    - in[1] - normalized_shape / required
    - in[2] - weight_opt / optional
    - in[3] - bias_opt / optional
    - in[4] - eps / required

  By comparing the [equation](https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html), we can relate the arguments (`gamma` / `beta` / `epsilon`) in the Qualcomm manual to PyTorch's (`weight_opt` / `bias_opt` / `eps`). The unmatched parameter `axes` is discussed further in the [Implementation](#implementation) section.
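
  To make the correspondence concrete, here is a short sketch (our own annotation, not taken from either document) of how the FX node's arguments line up with QNN's `LayerNorm` IOs:
  ```python
  # A hedged summary of the mapping; indices follow the declaration of
  # aten.native_layer_norm.default shown above.
  input_node  = node.args[0]  # -> in[0] input activation
  norm_shape  = node.args[1]  # -> no QNN input counterpart; drives the "axes" parameter
  weight_node = node.args[2]  # -> in[1] gamma (weight_opt)
  bias_node   = node.args[3]  # -> in[2] beta (bias_opt)
  eps         = node.args[4]  # -> "epsilon" parameter (eps)
  ```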

### Implementation
Let's start by adding a new definition in `qnn_constants.py` for the `LayerNorm` operator.
```python
@dataclass(init=False, frozen=True)
class OpHardSwish:
    ...


# please insert it in alphabetical order
@dataclass(init=False, frozen=True)
class OpLayerNorm:
    op_name: str = "LayerNorm"
    param_epsilon = "epsilon"
    param_axes = "axes"


@dataclass(init=False, frozen=True)
class OpLogSoftmax:
    ...
```
The conventions are:
- `op_name`: string describing the operator
- `param_xxx`: string for each consumed parameter

The content should exactly match the literal values mentioned in the [Qualcomm AI Engine Direct Manual](#qualcomm-ai-engine-direct) or `QnnOpDef.h` under `$QNN_SDK_ROOT/include/QNN/`:
```c
#define QNN_OP_LAYER_NORM "LayerNorm"
#define QNN_OP_LAYER_NORM_PARAM_EPSILON "epsilon"
#define QNN_OP_LAYER_NORM_PARAM_AXES "axes"
```

Next, create a new file named in snake case (e.g. `op_layer_norm.py`) and import the required modules (see the comments for usage hints):
```python
# Dict is used in define_node's type hints below
from typing import Dict

# pybind interface for invoking QNN APIs
import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper
# tensors or other numerics will be shipped in numpy format
import numpy as np
import torch
# common keywords of Qualcomm backend
from executorch.backends.qualcomm.utils.constants import QCOM_DATA
# op builder will inherit NodeVisitor and have its own implementation
# register_node_visitor for book-keeping the dictionary of target name v.s. callback
from .node_visitor import NodeVisitor, register_node_visitor
# the definitions required to build the operator in QNN
from .qnn_constants import OpLayerNorm, QNN_OP_PACKAGE_NAME_QTI_AISW
# utility to get parameter value when creating tensor in QNN
from .utils import get_parameter
```
Start with the class and function declaration:
```python
@register_node_visitor
class LayerNormVisitor(NodeVisitor):
    target = ["aten.native_layer_norm.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(
        self,
        node: torch.fx.Node,
        nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper],
    ) -> PyQnnWrapper.PyQnnOpWrapper:
```
It is mandatory for the `target` member to be in list form, since multiple targets may map to the same implementation, e.g. `aten.leaky_relu.default` and `aten.prelu.default` have similar equations and differ only in the negative slope.<br/>
The `nodes_to_wrappers` argument is a dictionary maintaining the relationship between a graph node and its output tensor. It acts as a memo so that tensor objects are not recreated for nodes that have already been traversed.<br/>
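
For illustration only, a visitor serving multiple targets might look like the following sketch (the class name and branching are hypothetical, not an existing builder):
```python
@register_node_visitor
class ExampleActivationVisitor(NodeVisitor):
    # one implementation registered for two similar targets
    target = ["aten.leaky_relu.default", "aten.prelu.default"]

    def __init__(self, *args) -> None:
        super().__init__(*args)

    def define_node(self, node, nodes_to_wrappers):
        # branch on the target name where the operators differ,
        # e.g. in how the negative slope is obtained
        if node.target.__name__ == "aten.leaky_relu.default":
            ...
        else:
            ...
```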

Now we can fill in the function body step by step:
1. Define the input activation tensors:
    ```python
    input_node = node.args[0]
    input_tensor = self.get_tensor(input_node, node)
    input_tensor_wrapper = self.define_tensor(
        input_node,
        input_tensor,
        PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
        nodes_to_wrappers,
        is_input_tensor=True,
    )
    ```
    Using the information from the [Check Operator Spec](#check-operator-spec) section, we can easily extract the desired nodes.<br/>
    The `get_tensor` method is responsible for retrieving the torch tensor in the correct axis order if the `layout_transform` pass has been applied.<br/>
    The `define_tensor` method generates a tensor object for the QNN API, which will be memorized by the aforementioned `nodes_to_wrappers`.<br/>
    A few arguments are worth explaining further:
    - **node**: current graph node
    - **tensor**: torch tensor emitted by node
    - **tensor_type**: type compatible with the QNN SDK; often `QNN_TENSOR_TYPE_NATIVE` for intermediate outputs and `QNN_TENSOR_TYPE_STATIC` for constant parameters
    - **nodes_to_wrappers**: dictionary of graph nodes and their output tensors (note: the tensor here is not a torch tensor but a wrapped object for QNN)
    - **is_input_tensor**: flag telling whether the current tensor is an input activation or a parameter, which is important for fixed-point mixed-precision to work properly
    - **node_name**: (optional) tensor name for the user to specify
    - **wrapper_idx**: (optional) defaults to zero if the node is not a tuple, otherwise it acts as an index into the output tensors. e.g. when slicing an input tensor into multiple outputs, `wrapper_idx` is necessary for getting the correct wrapped tensor object (see the sketch after this list)
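
    As a hedged illustration of `wrapper_idx` (a hypothetical scenario, not part of the LayerNorm builder): suppose `node` were an operator emitting a tuple of outputs; selecting the second one might look like:
    ```python
    # index 1 picks the second torch tensor and the second wrapped object
    second_tensor = self.get_tensor(node, node, 1)
    second_wrapper = self.define_tensor(
        node,
        second_tensor,
        PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
        nodes_to_wrappers,
        is_input_tensor=False,
        wrapper_idx=1,
    )
    ```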

2. Define the input gamma / beta tensors:
    ```python
    weight_node = node.args[2]
    weight_tensor = get_parameter(weight_node, self.edge_program)
    weight_tensor_wrapper = self.define_tensor(
        weight_node,
        weight_tensor,
        PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
        nodes_to_wrappers,
        is_input_tensor=False,
    )

    bias_node = node.args[3]
    bias_tensor = get_parameter(bias_node, self.edge_program)
    bias_tensor_wrapper = self.define_tensor(
        bias_node,
        bias_tensor,
        PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC,
        nodes_to_wrappers,
        is_input_tensor=False,
    )
    ```
    The logic here is similar and straightforward; please set the `tensor_type` and `is_input_tensor` arguments carefully according to each tensor's properties.

3. Define the parameters:
    ```python
    normalized_shapes = node.args[1]
    if len(normalized_shapes) != 1:
        print("QNN only supports normalized output with rank 1")
        return

    axes = [len(input_tensor.shape) - 1]
    axes_shape = [len(axes)]
    epsilon = node.args[4]
    ```
    Here you can see a constraint introduced by Qualcomm AI Engine Direct: unlike PyTorch's LayerNorm operator, QNN only supports a normalized shape of rank 1. Therefore we log a reminder for the user and return directly; the partitioner treats this as a validation failure and falls back this operator to CPU.<br/>
    When passing tensor-type parameters via the pybind interface, it is also required to ship extra information such as the tensor shape in list form, e.g. `axes_shape = [len(axes)]`. More details follow in the coming steps.

4. Define the output tensor:
    ```python
    output_tensor = self.get_tensor(node, node, 0)
    output_tensor_wrapper = self.define_tensor(
        node,
        output_tensor,
        PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE,
        nodes_to_wrappers,
        is_input_tensor=False,
    )
    ```
    Although the input / output activations might map to the graph IOs (a.k.a. user inputs / outputs) with the corresponding types `QNN_TENSOR_TYPE_APP_READ` / `QNN_TENSOR_TYPE_APP_WRITE`, users are still expected to use `QNN_TENSOR_TYPE_NATIVE` for all nodes' IOs and leave the detection logic to the `define_tensor` method.

5. Generate the operator object in the QNN graph:
    ```python
    layer_norm_op = PyQnnWrapper.PyQnnOpWrapper(
        node.name,
        QNN_OP_PACKAGE_NAME_QTI_AISW,
        OpLayerNorm.op_name,
    )
    ```

6. Pass the IO tensors to the operator object:
    ```python
    layer_norm_op.AddInputTensors(
        [input_tensor_wrapper, weight_tensor_wrapper, bias_tensor_wrapper]
    )
    layer_norm_op.AddOutputTensors([output_tensor_wrapper])
    ```
    The IO tensor objects created earlier are gathered up and shipped to the operator object.

7. Pass the parameters to the operator object:
    ```python
    layer_norm_op.AddScalarParam(
        OpLayerNorm.param_epsilon,
        PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32,
        {QCOM_DATA: np.float32(epsilon)},
    )
    layer_norm_op.AddTensorParam(
        OpLayerNorm.param_axes,
        PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32,
        len(axes_shape),
        axes_shape,
        np.array(axes, dtype=np.uint32),
        True,
    )
    ```
    By checking the `Shape` property of the parameter in the [Qualcomm AI Engine Direct Manual](#qualcomm-ai-engine-direct), it should be clear which API to use, e.g.:
    - "epsilon" > __Shape__: scalar
    - "axes" > __Shape__: 1D of shape[M]

    The function signature of `AddScalarParam` is:
    - **name**: string matching the parameter name in the Qualcomm AI Engine Direct manual
    - **data_type**: type compatible with the QNN SDK, e.g. `QNN_DATATYPE_FLOAT_32`, `QNN_DATATYPE_UINT_32`, etc.
    - **attr**: dictionary for shipping data; currently only the `QCOM_DATA` key is used

    The function signature of `AddTensorParam` is:
    - **name**: string matching the parameter name in the Qualcomm AI Engine Direct manual
    - **data_type**: type compatible with the QNN SDK, e.g. `QNN_DATATYPE_FLOAT_32`, `QNN_DATATYPE_UINT_32`, etc.
    - **rank**: number of dimensions of the tensor
    - **dims**: shape of the tensor
    - **data**: tensor data
    - **copy_data**: should be set to `True` for constant parameters

8. Finally, return the operator object for the partitioner to conduct validation:
    ```python
    return layer_norm_op
    ```
    Also update `__init__.py` for `register_node_visitor` to work properly:
    ```python
    from . import (
        ...
        op_index_put,
        # please insert codes in alphabetical order
        op_layer_norm,
        op_linear,
        ...
    )

    __all__ = [
        ...
        op_index_put,
        # please insert codes in alphabetical order
        op_layer_norm,
        op_linear,
        ...
    ]
    ```
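
With the builder registered, re-running the lowering from [Identify Unsupported Operator](#identify-unsupported-operator) should no longer raise the `KeyError`; a quick sanity re-check (same arguments as before) might look like:
```python
# re-run the build that previously failed on aten.native_layer_norm.default
from executorch.examples.qualcomm.utils import build_executorch_binary

build_executorch_binary(
    model=MyModel(),
    inputs=(torch.randn(200, 768),),
    soc_model="SM8650",
    file_name="my_model",
    dataset=None,
)
```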

### Quantizer Annotation
The operator should now be functional for Qualcomm backends. For the operator to work in fixed precision, we also need `QnnQuantizer` to correctly insert observers for recording calibrated encodings. Please read more in the [Quantization Annotation Tutorial](../quantizer/README.md).

## Issues
Please refer to the [issue section](../README.md#issues) for more information.

## Pull Requests
Please refer to the [PR section](../README.md#pull-requests) for more information.