# Contribution for Operator Annotation
Thank you for contributing to the Qualcomm AI Engine Direct delegate for ExecuTorch. Reading and following these guidelines will help you quickly grasp the essentials of annotating an operator in `QnnQuantizer`, unblock yourself, and land pull requests more efficiently.

## Sections
* [References](#references)
* [Getting Started](#getting-started)
* [Issues](#issues)
* [Pull Requests](#pull-requests)

## References
### Qualcomm AI Engine Direct
- [Operator Definitions for HTP](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/HtpOpDefSupplement.html)

### PyTorch
- [ATen Operator Definitions](https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native)

## Getting Started
Before extending an operator for quantization annotation, please make sure its operator builder has been well implemented (learn more in this [tutorial](../builders/README.md)).
### Behavior of Annotation
In order to conduct PTQ on a floating-point graph, observers are required to be inserted after each graph node. The observed numeric ranges go through different algorithms that produce the `scale` and `offset` statistics used to represent data in fixed point.<br/><br/>
**The stages can be illustrated as follows**:
- Floating point `nn.Module` after `torch.export.export`
  ```mermaid
  flowchart TB
      input & kernel & bias --> id1(convolution) --> output
  ```

- Inserting observers to inspect numeric ranges
  ```mermaid
  flowchart TB
      input --> id2(input_act_obs) --> id1(convolution) --> id3(output_act_obs) --> output
      kernel --> id4(weight_obs) --> id1(convolution)
      bias --> id5(bias_obs) --> id1(convolution)
  ```

- Cascaded QDQ pairs after landing encodings
  ```mermaid
  flowchart TB
      input --> id2(Q_i) --> id3(DQ_i) --> id1(convolution) --> id4(Q_o) --> id5(DQ_o) --> output
      kernel --> id6(Q_k) --> id7(DQ_k) --> id1(convolution)
      bias --> id8(Q_b) --> id9(DQ_b) --> id1(convolution)
  ```
The Qualcomm backend will consume the generated encodings and lower operators in fixed-point precision. This tutorial will guide you through the details of inserting observers and introduce some useful utilities.
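To see where annotation fits in practice, here is a minimal sketch of the end-to-end PT2E quantization flow that drives the stages above. The `QnnQuantizer` import path and the `prepare_pt2e` / `convert_pt2e` entry points are assumptions based on the standard PyTorch 2 export quantization APIs; adjust them to match your checkout:
```python
import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

# Assumed import path for the Qualcomm quantizer; adjust if it differs in your tree.
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer

model = torch.nn.Conv2d(3, 8, kernel_size=3)
example_inputs = (torch.randn(1, 3, 32, 32),)

# Stage 1: floating point nn.Module captured as an FX graph via torch.export.
graph_module = torch.export.export(model, example_inputs).module()

# Stage 2: prepare_pt2e invokes the registered annotators and inserts
# observers around the annotated nodes.
prepared = prepare_pt2e(graph_module, QnnQuantizer())
prepared(*example_inputs)  # calibration with representative data

# Stage 3: convert_pt2e replaces observers with cascaded QDQ pairs
# carrying the landed encodings.
quantized = convert_pt2e(prepared)
```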

### Register Annotation via Operator Type
Let's start by hooking a callback for the designated operator targets:
```python
def register_annotator(ops: List[OpOverload]):
    def decorator(annotator: Callable):
        for op in ops:
            OP_ANNOTATOR[op] = annotator

    return decorator
```
The `register_annotator` decorator provides a convenient way to attach your own annotation logic; it takes a list of operator types as its input argument.<br/> For example, the torch activation functions have `copy` and `in-place` implementations with a small difference in naming (an extra `_` postfix), which map to the same [Core ATen](https://pytorch.org/docs/stable/torch.compiler_ir.html) operators after `to_edge`:
```python
@register_annotator([torch.ops.aten.relu.default, torch.ops.aten.relu_.default])
```
Here, `torch.ops.aten.relu.default` / `torch.ops.aten.relu_.default` map to the `copy` / `in-place` versions respectively, and both will ultimately be converted into `torch.ops.aten.relu.default`.<br/><br/>

The annotator's signature is defined as follows with two arguments:
```python
def annotate_xxx(node: Node, quantization_config: QuantizationConfig) -> None:
```
- __node__: graph node required to be observed
- __quantization_config__: data structure describing quantization configurations for IO activation / weight / bias

### Example of Conv2d Annotation
Conv2d accepts up to three input tensors: `input activation`, `kernel`, `bias`. There are constraints imposed by the [Qualcomm AI Engine Direct Manual](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/HtpOpDefSupplement.html#conv2d).<br/>
Take 8-bit fixed point as an example:
- __weight__: must be symmetrically quantized if a per-channel observer is applied
- __bias__: must have `QNN_DATATYPE_SFIXED_POINT_32` and be symmetrically quantized with the expected encoding `scales = weight.scales * input.scale`, `offset = 0` if a per-channel observer is applied.

Let's look at the simplified per-channel quantization configuration used in `QnnQuantizer`:
```python
def ptq_per_channel_quant_config(
    act_dtype=torch.uint8, weight_dtype=torch.int8
) -> QuantizationConfig:
    ...
    act_quantization_spec = QuantizationSpec(
        dtype=act_dtype,
        quant_min=torch.iinfo(act_dtype).min,
        quant_max=torch.iinfo(act_dtype).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    weight_quantization_spec = QuantizationSpec(
        dtype=torch.int8,
        quant_min=torch.iinfo(weight_dtype).min + 1,
        quant_max=torch.iinfo(weight_dtype).max,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args),
    )

    bias_quantization_spec = _derived_bias_quant_spec

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config
```
Here we choose `torch.uint8` + `MinMaxObserver` for better coverage of IO activations, and apply the rules above to `weight` with `PerChannelMinMaxObserver` and `bias` with `_derived_bias_quant_spec` (a callable that computes the encoding in the desired way) to meet the aforementioned constraints. The well-defined `quantization_config` will then be shipped to the callback for annotation.<br/>
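
To make the bias constraint concrete, below is a hypothetical, simplified sketch of what such a derived bias spec could look like, assuming the `DerivedQuantizationSpec` API from `torch.ao.quantization.quantizer`; the actual `_derived_bias_quant_spec` in the quantizer sources may handle more cases:
```python
import torch
from torch.ao.quantization.quantizer import DerivedQuantizationSpec
from torch.fx import Node


# Hypothetical simplified version; see _derived_bias_quant_spec in the
# quantizer sources for the actual implementation.
def derived_bias_quant_spec(node: Node) -> DerivedQuantizationSpec:
    def derive_qparams_fn(obs_or_fqs):
        # Encodings are derived from the already-observed activation and weight:
        # scales = weight.scales * input.scale (per channel), offset = 0.
        act_obs, weight_obs = obs_or_fqs
        act_scale, _ = act_obs.calculate_qparams()
        weight_scale, _ = weight_obs.calculate_qparams()
        scale = act_scale.reshape(-1) * weight_scale.reshape(-1)
        return scale, torch.zeros_like(scale, dtype=torch.int32)

    input_act, weight = node.args[0], node.args[1]
    return DerivedQuantizationSpec(
        derived_from=[(input_act, node), (weight, node)],
        derive_qparams_fn=derive_qparams_fn,
        dtype=torch.int32,  # maps to QNN_DATATYPE_SFIXED_POINT_32
        quant_min=torch.iinfo(torch.int32).min,
        quant_max=torch.iinfo(torch.int32).max,
        ch_axis=0,
        qscheme=torch.per_channel_symmetric,
    )
```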

Now, we can start to fill in the function body:
- Register the annotator
  ```python
  @register_annotator(
      [
          torch.ops.aten.conv2d.default,
          torch.ops.aten.conv1d.default,
          torch.ops.aten.conv_transpose2d.input,
      ]
  )
  def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
  ```
  Multiple targets are expected to meet the same annotation criteria here; registering them together is encouraged for code reuse.

- Define the map of input quantization specs
  ```python
  if _is_annotated([node]):
      return

  input_qspec_map = {}

  # annotate input activation
  input_act = node.args[0]
  input_spec = quantization_config.input_activation
  input_qspec_map[input_act] = input_spec

  # annotate kernel
  kernel = node.args[1]
  input_qspec_map[kernel] = quantization_config.weight

  # annotate bias
  if len(node.args) > 2:
      bias = node.args[2]
      input_qspec_map[bias] = quantization_config.bias(node)
  ```
  We first check whether the current graph node has already been annotated. If not, an `input_qspec_map` dictionary required by the PyTorch framework is declared to map input graph nodes to their quantization specifications.<br/>
  The order of the parameters can be found [here](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Convolution.cpp), as mentioned in [ATen Operator Definitions](#pytorch). Since the bias node is optional, the implementation invokes `_derived_bias_quant_spec` to calculate the per-channel bias encoding only if it exists.

- Update the node's meta with a framework-compatible data structure
  ```python
  node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
      input_qspec_map=input_qspec_map,
      output_qspec=quantization_config.output_activation,
      _annotated=True,
  )
  ```
  After processing `input_qspec_map`, it's required to store it in the node's meta under a special tag (`QUANT_ANNOTATION_KEY`) so that `prepare_pt2e` / `convert_pt2e` can properly insert observers and the corresponding QDQ pairs.

### Common Annotators
For operators without extra parameters to be observed, there are pre-defined annotation methods for convenience:
- Single-in-single-out operators, e.g.:
  ```python
  @register_annotator([torch.ops.aten.relu.default, torch.ops.aten.relu_.default])
  def annotate_relu(node: Node, quantization_config: QuantizationConfig) -> None:
      annotate_single_in_single_out(node, quantization_config)
  ```

- Binary-in-single-out operators, e.g.:
  ```python
  @register_annotator([torch.ops.aten.add, torch.ops.aten.add.Tensor])
  def annotate_add(node: Node, quantization_config: QuantizationConfig) -> None:
      annotate_binary(node, quantization_config)
  ```

- Shared encodings between input / output, e.g.:<br/>
  ```python
  # For operators without an arithmetic function, inputs and outputs are expected to share the same encodings.
  @register_annotator([torch.ops.aten.transpose.int])
  def annotate_transpose(node: Node, quantization_config: QuantizationConfig) -> None:
      annotate_in_out_obs_sharing_op(node, quantization_config)
      if not _is_annotated([node]):
          annotate_single_in_single_out(node, quantization_config)
  ```
  This annotator only works in the single-in-single-out scenario where the node's input has already been annotated. If it hasn't, we still need to fall back to `annotate_single_in_single_out` (this path should be less likely). A sketch of how such a sharing helper might look is shown below.
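
For reference, a sharing helper along these lines could be built with `SharedQuantizationSpec`, so that the node reuses its producer's encodings. This is a sketch under the assumption that the shared spec API from `torch.ao.quantization.quantizer` is used, together with the `_is_annotated`, `QUANT_ANNOTATION_KEY`, and `QuantizationConfig` utilities from the examples above; the actual `annotate_in_out_obs_sharing_op` in the quantizer sources may differ in detail:
```python
from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    SharedQuantizationSpec,
)
from torch.fx import Node


# Hypothetical sketch of an input/output observer-sharing annotator.
def annotate_in_out_obs_sharing_op_sketch(
    node: Node, quantization_config: QuantizationConfig
) -> None:
    if _is_annotated([node]):
        return

    input_act = node.args[0]
    # Only share when the producer has already been annotated with an output spec.
    annotation = input_act.meta.get(QUANT_ANNOTATION_KEY)
    if annotation is None or not annotation._annotated or annotation.output_qspec is None:
        return

    # Both the input edge and the output reuse the producer's encodings.
    shared_qspec = SharedQuantizationSpec(input_act)
    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
        input_qspec_map={input_act: shared_qspec},
        output_qspec=shared_qspec,
        _annotated=True,
    )
```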

## Issues
Please refer to the [issue section](../README.md#issues) for more information.

## Pull Requests
Please refer to the [PR section](../README.md#pull-requests) for more information.