# Contribution for Operator Annotation
Thank you for contributing to the Qualcomm AI Engine Direct delegate for ExecuTorch. Reading and following these guidelines will help you quickly grasp the essentials of annotating an operator in `QnnQuantizer`, unblock yourself, and land pull requests more efficiently.

## Sections
* [References](#references)
* [Getting Started](#getting-started)
* [Issues](#issues)
* [Pull Requests](#pull-requests)

## References
### Qualcomm AI Engine Direct
- [Operator Definitions for HTP](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/HtpOpDefSupplement.html)

### PyTorch
- [ATen Operator Definitions](https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/native)
## Getting Started
Before extending an operator for quantization annotation, please make sure its operator builder has been well implemented (learn more in this [tutorial](../builders/README.md)).
### Behavior of Annotation
To conduct PTQ on a floating-point precision graph, observers must be inserted after each graph node. The observed numeric ranges go through different algorithms and return `scale` / `offset` statistics used to represent the data in fixed point.<br/><br/>
**The stages can be illustrated as follows**:
- Floating point `nn.Module` after `torch.export.export`
    ```mermaid
    flowchart TB
        input & kernel & bias --> id1(convolution) --> output
    ```

- Inserting observers for inspecting numeric range
    ```mermaid
    flowchart TB
        input --> id2(input_act_obs) --> id1(convolution) --> id3(output_act_obs) --> output
        kernel --> id4(weight_obs) --> id1(convolution)
        bias --> id5(bias_obs) --> id1(convolution)
    ```

- Cascade QDQ pairs after landing encodings
    ```mermaid
    flowchart TB
        input --> id2(Q_i) --> id3(DQ_i) --> id1(convolution) --> id4(Q_o) --> id5(DQ_o) --> output
        kernel --> id6(Q_k) --> id7(DQ_k) --> id1(convolution)
        bias --> id8(Q_b) --> id9(DQ_b) --> id1(convolution)
    ```
The Qualcomm backend will consume the generated encodings and lower operators with fixed precision. This tutorial will guide you through the details of inserting observers and some useful utilities.

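As a concrete illustration of the `scale` / `offset` statistics mentioned above, here is a minimal plain-Python sketch (the helper names are hypothetical, not ExecuTorch APIs) of how a min/max observer could derive an 8-bit asymmetric encoding, and of the quantize / dequantize round trip it enables:

```python
def derive_encoding(obs_min, obs_max, quant_min=0, quant_max=255):
    # Extend the observed range to include zero so zero (e.g. padding) stays exact.
    obs_min, obs_max = min(obs_min, 0.0), max(obs_max, 0.0)
    scale = (obs_max - obs_min) / (quant_max - quant_min)
    offset = round(quant_min - obs_min / scale)
    return scale, offset

def quantize(x, scale, offset, quant_min=0, quant_max=255):
    # Map a float to the fixed-point grid, clamping to the representable range.
    return max(quant_min, min(quant_max, round(x / scale) + offset))

def dequantize(q, scale, offset):
    return (q - offset) * scale

scale, offset = derive_encoding(-1.0, 3.0)
restored = dequantize(quantize(0.5, scale, offset), scale, offset)
# The round-trip error is bounded by one quantization step (scale).
assert abs(restored - 0.5) <= scale
```

The round-trip error is bounded by one quantization step, which is the basic accuracy trade-off PTQ makes when representing float data in fixed point.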
### Register Annotation via Operator Type
Let's start by hooking a callback for the designated operator targets:
```python
def register_annotator(ops: List[OpOverload]):
    def decorator(annotator: Callable):
        for op in ops:
            OP_ANNOTATOR[op] = annotator

    return decorator
```
The `register_annotator` decorator provides a convenient way to attach your own annotation logic; it takes a list of operator overloads as its input argument.<br/> For example, the torch activation functions have `copy` and `in-place` implementations whose only difference appears in the naming (an extra `_` suffix), and they map to the same [Core ATen](https://pytorch.org/docs/stable/torch.compiler_ir.html) operators after `to_edge`:
```python
@register_annotator([torch.ops.aten.relu.default, torch.ops.aten.relu_.default])
```
Here `torch.ops.aten.relu.default` / `torch.ops.aten.relu_.default` map to the `copy` / `in-place` versions respectively, and both will ultimately be converted into `torch.ops.aten.relu.default`.<br/><br/>

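The registry pattern behind `register_annotator` can be exercised in isolation. Here is a self-contained sketch, with plain strings standing in for `OpOverload` targets (the inner decorator here also returns the annotator so the decorated name stays bound, a minor difference from the snippet above):

```python
from typing import Callable, Dict, List

# Registry mapping each operator key to its annotation callback.
OP_ANNOTATOR: Dict[str, Callable] = {}

def register_annotator(ops: List[str]):
    def decorator(annotator: Callable):
        for op in ops:
            OP_ANNOTATOR[op] = annotator
        return annotator
    return decorator

@register_annotator(["aten.relu.default", "aten.relu_.default"])
def annotate_relu(node, quantization_config):
    return f"annotated {node} with {quantization_config}"

# Both keys dispatch to the same callback, mirroring copy / in-place ops.
assert OP_ANNOTATOR["aten.relu.default"] is OP_ANNOTATOR["aten.relu_.default"]
```

The quantizer later dispatches through the table (`OP_ANNOTATOR[node.target](node, config)`) rather than calling annotators directly.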
The function signature is defined as follows, with two arguments:
```python
def annotate_xxx(node: Node, quantization_config: QuantizationConfig) -> None:
```
- __node__: graph node required to be observed
- __quantization_config__: data structure describing the quantization configurations for IO activation / weight / bias

### Example of Conv2d Annotation
Conv2d accepts up to three input tensors: `input activation`, `kernel`, `bias`. There are constraints imposed by the [Qualcomm AI Engine Direct Manual](https://docs.qualcomm.com/bundle/publicresource/topics/80-63442-50/HtpOpDefSupplement.html#conv2d).<br/>
Take 8-bit fixed point as an example:
- __weight__: must be symmetrically quantized if a per-channel observer is applied
- __bias__: must have `QNN_DATATYPE_SFIXED_POINT_32` and be symmetrically quantized with the expected encoding `scales = weight.scales * input.scale`, `offset = 0` if a per-channel observer is applied.

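The bias constraint above can be sketched numerically in plain Python (hypothetical helper names, not the actual `_derived_bias_quant_spec` implementation):

```python
INT32_MIN, INT32_MAX = -(2**31), 2**31 - 1

def derive_bias_encoding(weight_scales, input_scale):
    # Symmetric quantization: offset is fixed to 0; each channel's bias
    # scale is the product of that channel's weight scale and the input scale.
    return [ws * input_scale for ws in weight_scales], 0

def quantize_bias(bias, bias_scales):
    # Quantize into the signed 32-bit range (QNN_DATATYPE_SFIXED_POINT_32).
    return [
        max(INT32_MIN, min(INT32_MAX, round(b / s)))
        for b, s in zip(bias, bias_scales)
    ]

bias_scales, offset = derive_bias_encoding([0.02, 0.04], input_scale=0.5)
assert offset == 0
assert quantize_bias([1.0, -2.0], bias_scales) == [100, -100]
```

Deriving the bias scale from the weight and input scales lets the backend fold the bias addition into the integer accumulator without rescaling.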
Let's look at the simplified per-channel quantization configuration used in `QnnQuantizer`:
```python
def ptq_per_channel_quant_config(
    act_dtype=torch.uint8, weight_dtype=torch.int8
) -> QuantizationConfig:
    ...
    act_quantization_spec = QuantizationSpec(
        dtype=act_dtype,
        quant_min=torch.iinfo(act_dtype).min,
        quant_max=torch.iinfo(act_dtype).max,
        qscheme=torch.per_tensor_affine,
        observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args),
    )

    weight_quantization_spec = QuantizationSpec(
        dtype=weight_dtype,
        quant_min=torch.iinfo(weight_dtype).min + 1,
        quant_max=torch.iinfo(weight_dtype).max,
        qscheme=torch.per_channel_symmetric,
        ch_axis=0,
        observer_or_fake_quant_ctr=PerChannelMinMaxObserver.with_args(**extra_args),
    )

    bias_quantization_spec = _derived_bias_quant_spec

    quantization_config = QuantizationConfig(
        input_activation=act_quantization_spec,
        output_activation=act_quantization_spec,
        weight=weight_quantization_spec,
        bias=bias_quantization_spec,
    )

    return quantization_config
```
Here we choose `torch.uint8` + `MinMaxObserver` for better coverage of the IO activations, and meet the aforementioned constraints by applying `PerChannelMinMaxObserver` to `weight` and `_derived_bias_quant_spec` (a callable that calculates the encoding in the desired way) to `bias`. The well-defined `quantization_config` will then be passed to the callback for annotation.<br/>

Now, we can start to fill in the function body:
- Register annotator
    ```python
    @register_annotator(
        [
            torch.ops.aten.conv2d.default,
            torch.ops.aten.conv1d.default,
            torch.ops.aten.conv_transpose2d.input,
        ]
    )
    def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
    ```
    Multiple targets are expected to meet the same annotation criteria; registering them together like this is encouraged for code reuse.

- Define map of input quantization spec
    ```python
        if _is_annotated([node]):
            return

        input_qspec_map = {}

        # annotate input activation
        input_act = node.args[0]
        input_spec = quantization_config.input_activation
        input_qspec_map[input_act] = input_spec

        # annotate kernel
        kernel = node.args[1]
        input_qspec_map[kernel] = quantization_config.weight

        # annotate bias
        if len(node.args) > 2:
            bias = node.args[2]
            input_qspec_map[bias] = quantization_config.bias(node)
    ```
    We first check whether the current graph node has already been annotated. If not, an `input_qspec_map` dictionary required by the PyTorch framework is declared to provide the mapping between graph nodes and their configurations.<br/>
    The parameters' order can be found [here](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/Convolution.cpp), mentioned in [ATen Operator Definitions](#pytorch). Since the bias node is optional, the implementation invokes `_derived_bias_quant_spec` to calculate the per-channel bias encoding only if it exists.

- Update node's meta with framework compatible data structure
    ```python
        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
            input_qspec_map=input_qspec_map,
            output_qspec=quantization_config.output_activation,
            _annotated=True,
        )
    ```
    After processing `input_qspec_map`, it must be stored in the node's meta under a special tag (`QUANT_ANNOTATION_KEY`) so that `prepare_pt2e` can properly insert observers.

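Putting the three steps together, the whole flow can be sketched end-to-end in plain Python (dict-based stand-ins for FX nodes and quantization specs; not ExecuTorch's actual classes):

```python
def annotate_conv2d(node, config, annotation_key="quantization_annotation"):
    # Step 1: skip nodes that are already annotated.
    if annotation_key in node["meta"]:
        return
    # Step 2: map each tensor-producing argument to its spec
    # (args[0] = input activation, args[1] = kernel, args[2] = optional bias).
    args = node["args"]
    input_qspec_map = {
        args[0]: config["input_activation"],
        args[1]: config["weight"],
    }
    if len(args) > 2 and args[2] is not None:
        input_qspec_map[args[2]] = config["bias"](node)  # derived bias spec
    # Step 3: record the annotation in the node's meta.
    node["meta"][annotation_key] = {
        "input_qspec_map": input_qspec_map,
        "output_qspec": config["output_activation"],
        "_annotated": True,
    }

config = {
    "input_activation": "act_spec",
    "output_activation": "act_spec",
    "weight": "weight_spec",
    "bias": lambda node: "derived_bias_spec",  # stands in for _derived_bias_quant_spec
}
node = {"meta": {}, "args": ("input_act", "kernel", "bias")}
annotate_conv2d(node, config)
assert node["meta"]["quantization_annotation"]["_annotated"] is True
```

Note how the optional bias is handled: a node without a third argument simply gets a two-entry `input_qspec_map`.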
### Common Annotators
For operators without extra parameters to be observed, there are pre-defined annotation methods for convenience:
- Single-in-single-out operators, e.g.:
    ```python
    @register_annotator([torch.ops.aten.relu.default, torch.ops.aten.relu_.default])
    def annotate_relu(node: Node, quantization_config: QuantizationConfig) -> None:
        annotate_single_in_single_out(node, quantization_config)
    ```

- Binary-in-single-out operators, e.g.:
    ```python
    @register_annotator([torch.ops.aten.add, torch.ops.aten.add.Tensor])
    def annotate_add(node: Node, quantization_config: QuantizationConfig) -> None:
        annotate_binary(node, quantization_config)
    ```

- Shared encodings between input / output, e.g.:<br/>
    ```python
    # For operators without an arithmetical function, IOs are expected to share the same encodings.
    @register_annotator([torch.ops.aten.transpose.int])
    def annotate_transpose(node: Node, quantization_config: QuantizationConfig) -> None:
        annotate_in_out_obs_sharing_op(node, quantization_config)
        if not _is_annotated([node]):
            annotate_single_in_single_out(node, quantization_config)
    ```
    This annotator only works in the single-in-single-out scenario where the node's input has already been annotated. If it has not, we still need to fall back to `annotate_single_in_single_out` (this path should be less likely).

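The sharing-then-fallback logic above can be modeled with a small plain-Python sketch (dict-based stand-ins for FX nodes and encodings; not the actual `annotate_in_out_obs_sharing_op` implementation):

```python
def annotate_shape_op(node, annotations, fresh_config):
    src = node["input"]
    if src in annotations:
        # Input already annotated: the output reuses (shares) its encoding,
        # so no new observer is introduced for this op.
        annotations[node["name"]] = annotations[src]
    else:
        # Fallback: behave like annotate_single_in_single_out and give
        # both input and output a fresh configuration.
        annotations[src] = fresh_config
        annotations[node["name"]] = fresh_config

annotations = {"conv_out": "enc_a"}
annotate_shape_op({"name": "transpose_out", "input": "conv_out"}, annotations, "enc_b")
assert annotations["transpose_out"] == "enc_a"  # shared, not a fresh encoding
```

Sharing keeps shape-manipulating ops like `transpose` from introducing an extra requantization step between producer and consumer.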
## Issues
Please refer to the [issue section](../README.md#issues) for more information.

## Pull Requests
Please refer to the [PR section](../README.md#pull-requests) for more information.