xref: /aosp_15_r20/external/executorch/backends/qualcomm/quantizer/annotators.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6import numbers
7import operator
8from functools import partial
9from typing import Callable, Dict, List, Sequence, Tuple
10
11import torch
12from torch._ops import OpOverload
13
14from torch._subclasses import FakeTensor
15from torch.ao.quantization.fake_quantize import FixedQParamsFakeQuantize
16
17from torch.ao.quantization.observer import FixedQParamsObserver
18from torch.ao.quantization.quantizer import (
19    DerivedQuantizationSpec,
20    QuantizationAnnotation,
21    QuantizationSpec,
22    SharedQuantizationSpec,
23)
24from torch.ao.quantization.quantizer.utils import (
25    _annotate_input_qspec_map,
26    _annotate_output_qspec,
27)
28from torch.fx import Node
29
30from .qconfig import (
31    get_16a16w_qnn_ptq_config,
32    get_16a4w_qnn_qat_config,
33    get_8a8w_qnn_qat_config,
34    QuantizationConfig,
35)
36
37
38QUANT_ANNOTATION_KEY = "quantization_annotation"
39OP_ANNOTATOR: Dict[OpOverload, Callable] = {}
40
41
42def register_annotator(ops: List[OpOverload]):
43    def decorator(annotator: Callable):
44        for op in ops:
45            OP_ANNOTATOR[op] = annotator
46        return annotator
47    return decorator
48
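# Illustrative sketch (not part of the upstream module): how the registry above is meant
# to be consumed. `graph_module` is a hypothetical stand-in for an exported
# torch.fx.GraphModule and `config` for the chosen QuantizationConfig; the real quantizer
# applies additional filtering before dispatching into OP_ANNOTATOR.
def _example_annotate_graph(
    graph_module: torch.fx.GraphModule, config: QuantizationConfig
) -> None:
    for node in graph_module.graph.nodes:
        # every aten op with a registered annotator gets annotated in place
        if node.op == "call_function" and node.target in OP_ANNOTATOR:
            OP_ANNOTATOR[node.target](node, config)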
49
50def _is_annotated(nodes: List[Node]):
51    """
52    Given a list of nodes (that represents an operator pattern),
53    return True if any of the nodes
54    is annotated, otherwise return False.
55    """
56    annotated = False
57    for node in nodes:
58        annotated = annotated or (
59            QUANT_ANNOTATION_KEY in node.meta
60            and node.meta[QUANT_ANNOTATION_KEY]._annotated
61        )
62    return annotated
63
64
65def _is_float_tensor(node: Node):
66    """Check if the node's tensor is a float tensor, so that we can skip quantization for the node
67    since observers only work with float Tensors
68    """
69    if (
70        not isinstance(node, Node)
71        or "val" not in node.meta
72        or not isinstance(node.meta["val"], FakeTensor)
73    ):
74        return False
75    return node.meta["val"].dtype == torch.float32
76
77
78def _mark_nodes_as_annotated(nodes: List[Node]):
79    for node in nodes:
80        if QUANT_ANNOTATION_KEY not in node.meta:
81            node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation()
82        node.meta[QUANT_ANNOTATION_KEY]._annotated = True
83
84
85def annotate_in_out_obs_sharing_op(
86    node: Node, quantization_config: QuantizationConfig
87) -> None:
88    if _is_annotated([node]):
89        return
90
91    input_act = node.args[0]
92    assert isinstance(input_act, Node)
93
94    # only annotate an input/output observer-sharing operator
95    # when the output of its input node is already annotated
96    if (
97        QUANT_ANNOTATION_KEY not in input_act.meta
98        or not input_act.meta[QUANT_ANNOTATION_KEY]._annotated
99        or input_act.meta[QUANT_ANNOTATION_KEY].output_qspec is None
100    ):
101        return
102
103    act_qspec = SharedQuantizationSpec(input_act)
104    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
105        input_qspec_map={
106            input_act: act_qspec,
107        },
108        output_qspec=act_qspec,
109        _annotated=True,
110    )
111
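# Sketch of the sharing behaviour above, assuming a hypothetical pair of nodes
# `relu -> permute` where `relu` already carries an annotated output_qspec: the permute
# node reuses relu's observer on both sides, so no new quantization parameters are
# introduced for a pure data-movement op.
#
#   permute.meta[QUANT_ANNOTATION_KEY] == QuantizationAnnotation(
#       input_qspec_map={relu: SharedQuantizationSpec(relu)},
#       output_qspec=SharedQuantizationSpec(relu),
#       _annotated=True,
#   )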
112
113def annotate_single_in_single_out(
114    node: Node, quantization_config: QuantizationConfig
115) -> None:
116    if _is_annotated([node]):
117        return
118
119    input_qspec_map = {}
120    input_act = node.args[0]
121    assert isinstance(input_act, Node)
122    input_qspec_map[input_act] = quantization_config.input_activation
123
124    if _is_float_tensor(node):
125        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
126            input_qspec_map=input_qspec_map,
127            output_qspec=quantization_config.output_activation,
128            _annotated=True,
129        )
130
131
132@register_annotator([torch.ops.aten.topk.default])
133def annotate_topk(node: Node, quantization_config: QuantizationConfig) -> None:
134    if _is_annotated([node]):
135        return
136    # We can use single_in_single_out since we don't want to quantize the indices output
137    annotate_single_in_single_out(node, quantization_config)
138
139
140def annotate_binary(node: Node, quantization_config: QuantizationConfig) -> None:
141    if _is_annotated([node]):
142        return
143
144    input_act_qspec = quantization_config.input_activation
145    output_act_qspec = (
146        quantization_config.output_activation if _is_float_tensor(node) else None
147    )
148
149    input_qspec_map = {}
150    input_act0 = node.args[0]
151    if _is_float_tensor(input_act0):
152        input_qspec_map[input_act0] = input_act_qspec
153
154    input_act1 = node.args[1]
155    if _is_float_tensor(input_act1):
156        input_qspec_map[input_act1] = input_act_qspec
157
158    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
159        input_qspec_map=input_qspec_map,
160        output_qspec=output_act_qspec,
161        _annotated=True,
162    )
163
164
165@register_annotator([torch.ops.aten.add, torch.ops.aten.add.Tensor])
166def annotate_add(node: Node, quantization_config: QuantizationConfig) -> None:
167    annotate_binary(node, quantization_config)
168
169
170@register_annotator([torch.ops.aten.sub, torch.ops.aten.sub.Tensor])
171def annotate_sub(node: Node, quantization_config: QuantizationConfig) -> None:
172    annotate_binary(node, quantization_config)
173
174
175@register_annotator(
176    [torch.ops.aten.mul, torch.ops.aten.mul.Tensor, torch.ops.aten.mul.Scalar]
177)
178def annotate_mul(node: Node, quantization_config: QuantizationConfig) -> None:
179    annotate_binary(node, quantization_config)
180
181
182@register_annotator(
183    [torch.ops.aten.div, torch.ops.aten.div.Tensor, torch.ops.aten.divide.Tensor]
184)
185def annotate_div(node: Node, quantization_config: QuantizationConfig) -> None:
186    def _derived_inp1_const_div_quant_spec(
187        node: torch.fx.Node, output_qspec: QuantizationSpec
188    ) -> DerivedQuantizationSpec:
189        def _derive_div_qparams_fn(
190            obs_or_fqs: List,
191            const_val: float,
192        ) -> Tuple[torch.Tensor, torch.Tensor]:
193            inp_0_obs_or_fq = obs_or_fqs[0]
194            inp_0_scale, inp_0_zp = inp_0_obs_or_fq.calculate_qparams()
195            derived_scale = inp_0_scale / const_val
196            return (derived_scale, inp_0_zp)
197
198        inp_0 = node.args[0]
199        const_inp_1 = node.args[1]
200        _derive_div_qparams_with_const_fn = partial(
201            _derive_div_qparams_fn, const_val=const_inp_1
202        )
203
204        q_min = (
205            torch.iinfo(output_qspec.dtype).min
206            if output_qspec.quant_min is None
207            else output_qspec.quant_min
208        )
209        q_max = (
210            torch.iinfo(output_qspec.dtype).max
211            if output_qspec.quant_max is None
212            else output_qspec.quant_max
213        )
214        return DerivedQuantizationSpec(
215            derived_from=[(inp_0, node)],
216            derive_qparams_fn=_derive_div_qparams_with_const_fn,
217            dtype=output_qspec.dtype,
218            quant_min=q_min,
219            quant_max=q_max,
220            ch_axis=0,
221            qscheme=output_qspec.qscheme,
222        )
223
224    if all(isinstance(a, Node) for a in node.args):
225        annotate_binary(node, quantization_config)
226    # special constant divisor case
227    elif isinstance(node.args[0], Node) and isinstance(node.args[1], numbers.Number):
228        if _is_annotated([node]):
229            return
230
231        input_act_qspec = quantization_config.input_activation
232        output_act_qspec = _derived_inp1_const_div_quant_spec(
233            node, quantization_config.output_activation
234        )
235        input_qspec_map = {}
236        input_act0 = node.args[0]
237        if _is_float_tensor(input_act0):
238            input_qspec_map[input_act0] = input_act_qspec
239
240        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
241            input_qspec_map=input_qspec_map,
242            output_qspec=output_act_qspec,
243            _annotated=True,
244        )
245    else:
246        raise NotImplementedError(f"No quant annotation is implemented for {node}.")
247
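# Worked check for the derived divisor spec above (a sketch with made-up numbers): under
# the affine model x ~= scale * (q - zero_point), dividing by a constant c keeps the same
# integer values and simply shrinks the scale by c, which is what _derive_div_qparams_fn
# encodes.
def _example_div_by_const_scale() -> None:
    scale, zero_point, const = 0.02, 0, 4.0   # hypothetical quantization parameters
    q = torch.tensor(100)                     # a quantized value
    x = scale * (q - zero_point)              # dequantized input, ~2.0
    # dividing the float value equals dequantizing with the derived scale (scale / const)
    assert torch.isclose(x / const, (scale / const) * (q - zero_point))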
248
249@register_annotator([torch.ops.aten.rsub.Scalar])
250def annotate_rsub(node: Node, quantization_config: QuantizationConfig) -> None:
251    annotate_binary(node, quantization_config)
252
253
254@register_annotator([torch.ops.aten.sum.dim_IntList])
255def annotate_sum(node: Node, quantization_config: QuantizationConfig) -> None:
256    annotate_binary(node, quantization_config)
257
258
259@register_annotator([torch.ops.aten.ceil.default])
260def annotate_ceil(node: Node, quantization_config: QuantizationConfig) -> None:
261    annotate_single_in_single_out(node, quantization_config)
262
263
264@register_annotator([torch.ops.aten.clamp.default])
265def annotate_clamp(node: Node, quantization_config: QuantizationConfig) -> None:
266    annotate_single_in_single_out(node, quantization_config)
267
268
269@register_annotator([torch.ops.aten.relu.default, torch.ops.aten.relu_.default])
270def annotate_relu(node: Node, quantization_config: QuantizationConfig) -> None:
271    annotate_single_in_single_out(node, quantization_config)
272
273
274@register_annotator([torch.ops.aten.tanh.default])
275def annotate_tanh(node: Node, quantization_config: QuantizationConfig) -> None:
276    annotate_single_in_single_out(node, quantization_config)
277
278
279@register_annotator(
280    [torch.ops.aten.hardswish.default, torch.ops.aten.hardswish_.default]
281)
282def annotate_hardswish(node: Node, quantization_config: QuantizationConfig) -> None:
283    annotate_single_in_single_out(node, quantization_config)
284
285
286@register_annotator(
287    [torch.ops.aten.hardsigmoid.default, torch.ops.aten.hardsigmoid_.default]
288)
289def annotate_hardsigmoid(node: Node, quantization_config: QuantizationConfig) -> None:
290    annotate_single_in_single_out(node, quantization_config)
291
292
293@register_annotator([torch.ops.aten.hardtanh.default, torch.ops.aten.hardtanh_.default])
294def annotate_hardtanh(node: Node, quantization_config: QuantizationConfig) -> None:
295    annotate_single_in_single_out(node, quantization_config)
296
297
298@register_annotator([torch.ops.aten.mean.default])
299def annotate_mean(node: Node, quantization_config: QuantizationConfig) -> None:
300    annotate_single_in_single_out(node, quantization_config)
301
302
303@register_annotator([torch.ops.aten.max_pool2d.default])
304def annotate_max_pool2d(node: Node, quantization_config: QuantizationConfig) -> None:
305    annotate_single_in_single_out(node, quantization_config)
306
307
308@register_annotator([torch.ops.aten.max_pool2d_with_indices.default])
309def annotate_max_pool2d_with_indices(
310    node: Node, quantization_config: QuantizationConfig
311) -> None:
312    annotate_single_in_single_out(node, quantization_config)
313
314
315@register_annotator([torch.ops.aten.adaptive_avg_pool2d.default])
316def annotate_adaptive_avgpool2d(
317    node: Node, quantization_config: QuantizationConfig
318) -> None:
319    annotate_single_in_single_out(node, quantization_config)
320
321
322@register_annotator([torch.ops.aten.avg_pool2d.default])
323def annotate_avgpool2d(node: Node, quantization_config: QuantizationConfig) -> None:
324    annotate_single_in_single_out(node, quantization_config)
325
326
327@register_annotator([torch.ops.aten.permute.default])
328def annotate_permute(node: Node, quantization_config: QuantizationConfig) -> None:
329    annotate_in_out_obs_sharing_op(node, quantization_config)
330    if not _is_annotated([node]):
331        annotate_single_in_single_out(node, quantization_config)
332
333
334@register_annotator(
335    [
336        torch.ops.aten.leaky_relu.default,
337        torch.ops.aten.leaky_relu_.default,
338        torch.ops.aten.prelu.default,
339    ]
340)
341def annotate_prelu(node: Node, quantization_config: QuantizationConfig) -> None:
342    annotate_single_in_single_out(node, quantization_config)
343
344
345@register_annotator([torch.ops.aten.view.default, torch.ops.aten._unsafe_view.default])
346def annotate_view(node: Node, quantization_config: QuantizationConfig) -> None:
347    annotate_in_out_obs_sharing_op(node, quantization_config)
348    if not _is_annotated([node]):
349        annotate_single_in_single_out(node, quantization_config)
350
351
352@register_annotator([torch.ops.aten.pixel_shuffle.default])
353def annotate_pixel_shuffle_default(
354    node: Node, quantization_config: QuantizationConfig
355) -> None:
356    annotate_single_in_single_out(node, quantization_config)
357
358
359@register_annotator([torch.ops.aten.pixel_unshuffle.default])
360def annotate_pixel_unshuffle_default(
361    node: Node, quantization_config: QuantizationConfig
362) -> None:
363    annotate_single_in_single_out(node, quantization_config)
364
365
366@register_annotator([torch.ops.aten.upsample_bilinear2d.vec])
367def annotate_upsample_bilinear2d(
368    node: Node, quantization_config: QuantizationConfig
369) -> None:
370    annotate_single_in_single_out(node, quantization_config)
371
372
373@register_annotator([torch.ops.aten.upsample_nearest2d.vec])
374def annotate_upsample_nearest2d(
375    node: Node, quantization_config: QuantizationConfig
376) -> None:
377    annotate_single_in_single_out(node, quantization_config)
378
379
380@register_annotator(
381    [
382        torch.ops.aten.softmax.int,
383        torch.ops.aten._softmax.default,
384        torch.ops.aten._safe_softmax.default,
385    ]
386)
387def annotate_softmax(node: Node, quantization_config: QuantizationConfig) -> None:
388    annotate_single_in_single_out(node, quantization_config)
389
390
391@register_annotator([torch.ops.aten.log_softmax.int])
392def annotate_log_softmax(node: Node, quantization_config: QuantizationConfig) -> None:
393    annotate_single_in_single_out(node, quantization_config)
394
395
396@register_annotator([torch.ops.aten.pad.default])
397def annotate_pad(node: Node, quantization_config: QuantizationConfig) -> None:
398    annotate_single_in_single_out(node, quantization_config)
399
400
401@register_annotator([torch.ops.aten.reshape.default])
402def annotate_reshape(node: Node, quantization_config: QuantizationConfig) -> None:
403    annotate_single_in_single_out(node, quantization_config)
404
405
406@register_annotator([torch.ops.aten.select.int])
407def annotate_select(node: Node, quantization_config: QuantizationConfig) -> None:
408    annotate_single_in_single_out(node, quantization_config)
409
410
411@register_annotator([torch.ops.aten.mean.dim])
412def annotate_mean_dim(node: Node, quantization_config: QuantizationConfig) -> None:
413    annotate_single_in_single_out(node, quantization_config)
414
415
416@register_annotator([torch.ops.aten.slice.Tensor])
417def annotate_slice(node: Node, quantization_config: QuantizationConfig) -> None:
418    annotate_single_in_single_out(node, quantization_config)
419
420
421@register_annotator([torch.ops.aten.sqrt.default])
422def annotate_sqrt(node: Node, quantization_config: QuantizationConfig) -> None:
423    annotate_single_in_single_out(node, quantization_config)
424
425
426@register_annotator([torch.ops.aten.gelu.default])
427def annotate_gelu(node: Node, quantization_config: QuantizationConfig) -> None:
428    annotate_single_in_single_out(node, quantization_config)
429
430
431@register_annotator([torch.ops.aten.scaled_dot_product_attention.default])
432def annotate_scaled_dot_product_attention(
433    node: Node, quantization_config: QuantizationConfig
434) -> None:
435    annotate_single_in_single_out(node, quantization_config)
436
437
438@register_annotator(
439    [
440        torch.ops.aten.squeeze.default,
441        torch.ops.aten.squeeze.dim,
442        torch.ops.aten.squeeze_copy.dims,
443    ]
444)
445def annotate_squeeze(node: Node, quantization_config: QuantizationConfig) -> None:
446    annotate_in_out_obs_sharing_op(node, quantization_config)
447    if not _is_annotated([node]):
448        annotate_single_in_single_out(node, quantization_config)
449
450
451@register_annotator([torch.ops.aten.rms_norm.default])
452def annotate_rms_norm(node: Node, quantization_config: QuantizationConfig) -> None:
453    act_node = node.args[0]
454    weight_node = node.args[2]
455
456    if _is_annotated([node]):
457        return
458
459    # TODO: currently only supports 16a16w
460    _annotate_input_qspec_map(
461        node,
462        act_node,
463        quantization_config.input_activation,
464    )
465
466    _annotate_input_qspec_map(
467        node,
468        weight_node,
469        quantization_config.input_activation,
470    )
471    nodes_to_mark_annotated = [node]
472    _annotate_output_qspec(node, quantization_config.output_activation)
473    _mark_nodes_as_annotated(nodes_to_mark_annotated)
474
475
476@register_annotator([torch.ops.aten.rsqrt.default])
477def annotate_rsqrt(node: Node, quantization_config: QuantizationConfig) -> None:
478    annotate_single_in_single_out(node, quantization_config)
479
480
481@register_annotator([torch.ops.aten.sigmoid, torch.ops.aten.sigmoid.default])
482def annotate_sigmoid(node: Node, quantization_config: QuantizationConfig) -> None:
483    if _is_annotated([node]):
484        return
485
486    input_qspec_map = {}
487    input_act = node.args[0]
488    input_qspec_map[input_act] = quantization_config.input_activation
489
490    assert isinstance(input_act, Node)
491    out_qconf = quantization_config.output_activation
492
493    q_max = (
494        torch.iinfo(out_qconf.dtype).max
495        if out_qconf.quant_max is None
496        else out_qconf.quant_max
497    )
498    q_min = (
499        torch.iinfo(out_qconf.dtype).min
500        if out_qconf.quant_min is None
501        else out_qconf.quant_min
502    )
503
504    scale = 1 / (q_max - q_min + 1)
505
506    bias_obs_ctr = observer = FixedQParamsObserver.with_args(
507        scale=scale,
508        zero_point=0,
509        dtype=quantization_config.output_activation.dtype,
510        qscheme=torch.per_tensor_affine,
511        quant_max=q_max,
512        quant_min=q_min,
513    )
514    if quantization_config in (
515        get_8a8w_qnn_qat_config(),
516        get_16a4w_qnn_qat_config(),
517    ):
518        bias_obs_ctr = FixedQParamsFakeQuantize.with_args(
519            observer=observer,
520            scale=scale,
521            zero_point=0,
522            dtype=quantization_config.output_activation.dtype,
523            qscheme=torch.per_tensor_affine,
524            quant_max=q_max,
525            quant_min=q_min,
526        )
527
528    # make the sigmoid output map to the range [0, 1]
529    out_act_quantization_spec = QuantizationSpec(
530        dtype=quantization_config.output_activation.dtype,
531        quant_max=q_max,
532        quant_min=q_min,
533        observer_or_fake_quant_ctr=bias_obs_ctr,
534        qscheme=torch.per_tensor_affine,
535    )
536
537    if _is_float_tensor(node):
538        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
539            input_qspec_map=input_qspec_map,
540            output_qspec=out_act_quantization_spec,
541            _annotated=True,
542        )
543
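# Worked numbers for the fixed output qparams above (a sketch assuming an unsigned 8-bit
# output spec; other bit widths are handled the same way): with quant_min = 0 and
# quant_max = 255 the fixed scale is 1 / (255 - 0 + 1) = 1 / 256 and zero_point = 0, so
# the representable grid {0, 1/256, ..., 255/256} covers sigmoid's (0, 1) output range.
def _example_sigmoid_fixed_qparams() -> None:
    q_min, q_max = 0, 255                 # hypothetical uint8 bounds
    scale = 1 / (q_max - q_min + 1)       # mirrors the computation in annotate_sigmoid
    largest_representable = scale * (q_max - 0)
    assert abs(largest_representable - 255 / 256) < 1e-12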
544
545@register_annotator([torch.ops.aten.pow.Tensor_Scalar])
546def annotate_pow(node: Node, quantization_config: QuantizationConfig) -> None:
547    annotate_single_in_single_out(node, quantization_config)
548
549
550@register_annotator([torch.ops.aten.unsqueeze.default])
551def annotate_unsqueeze(node: Node, quantization_config: QuantizationConfig) -> None:
552    annotate_in_out_obs_sharing_op(node, quantization_config)
553    if not _is_annotated([node]):
554        annotate_single_in_single_out(node, quantization_config)
555
556
557@register_annotator(
558    [
559        torch.ops.aten.unsqueeze_copy.default,
560    ]
561)
562def annotate_unsqueeze_copy(
563    node: Node, quantization_config: QuantizationConfig
564) -> None:
565    annotate_in_out_obs_sharing_op(node, quantization_config)
566    if not _is_annotated([node]):
567        annotate_single_in_single_out(node, quantization_config)
568
569
570@register_annotator([torch.ops.aten.transpose.int])
571def annotate_transpose(node: Node, quantization_config: QuantizationConfig) -> None:
572    annotate_in_out_obs_sharing_op(node, quantization_config)
573    if not _is_annotated([node]):
574        annotate_single_in_single_out(node, quantization_config)
575
576
577@register_annotator([torch.ops.aten.embedding.default])
578def annotate_embedding(node: Node, quantization_config: QuantizationConfig) -> None:
579    weight = node.args[0]
580
581    input_qspec_map = {}
582    input_qspec_map[weight] = quantization_config.input_activation
583
584    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
585        input_qspec_map=input_qspec_map,
586        output_qspec=SharedQuantizationSpec((weight, node)),
587        _annotated=True,
588    )
589
590
591@register_annotator([torch.ops.aten.index.Tensor])
592def annotate_index(node: Node, quantization_config: QuantizationConfig) -> None:
593    annotate_in_out_obs_sharing_op(node, quantization_config)
594    if not _is_annotated([node]):
595        input_qspec_map = {}
596        input = node.args[0]
597        input_qspec_map[input] = quantization_config.input_activation
598        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
599            input_qspec_map=input_qspec_map,
600            output_qspec=SharedQuantizationSpec((input, node)),
601            _annotated=True,
602        )
603
604
605@register_annotator(
606    [torch.ops.aten.index_put.default, torch.ops.aten.index_put_.default]
607)
608def annotate_index_put(node: Node, quantization_config: QuantizationConfig) -> None:
609    input = node.args[0]
610    value = node.args[2]
611
612    input_qspec_map = {}
613    input_qspec_map[input] = quantization_config.input_activation
614    input_qspec_map[value] = SharedQuantizationSpec((input, node))
615
616    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
617        input_qspec_map=input_qspec_map,
618        output_qspec=SharedQuantizationSpec((input, node)),
619        _annotated=True,
620    )
621
622
623@register_annotator([torch.ops.aten.expand.default])
624def annotate_expand(node: Node, quantization_config: QuantizationConfig) -> None:
625    annotate_in_out_obs_sharing_op(node, quantization_config)
626    if not _is_annotated([node]):
627        annotate_single_in_single_out(node, quantization_config)
628
629
630@register_annotator([torch.ops.aten.group_norm.default])
631def annotate_group_norm(node: Node, quantization_config: QuantizationConfig) -> None:
632    act_node = node.args[0]
633    weight_node = node.args[2]
634    bias_node = None
635    if len(node.args) > 3:
636        bias_node = node.args[3]
637
638    if _is_annotated([node]):
639        return
640
641    _annotate_input_qspec_map(
642        node,
643        act_node,
644        quantization_config.input_activation,
645    )
646    _annotate_input_qspec_map(
647        node,
648        weight_node,
649        quantization_config.weight,
650    )
651    nodes_to_mark_annotated = [node, weight_node]
652    if bias_node:
653        _annotate_input_qspec_map(
654            node,
655            bias_node,
656            quantization_config.bias,
657        )
658        nodes_to_mark_annotated.append(bias_node)
659    _annotate_output_qspec(node, quantization_config.output_activation)
660    _mark_nodes_as_annotated(nodes_to_mark_annotated)
661
662
663@register_annotator([torch.ops.aten.flatten.using_ints])
664def annotate_flatten(node: Node, quantization_config: QuantizationConfig) -> None:
665    annotate_in_out_obs_sharing_op(node, quantization_config)
666    if not _is_annotated([node]):
667        annotate_single_in_single_out(node, quantization_config)
668
669
670@register_annotator([torch.ops.aten.stack.default])
671def annotate_stack(node: Node, quantization_config: QuantizationConfig) -> None:
672    input_qspec_map = {}
673    for input_act in node.args[0]:
674        assert isinstance(input_act, Node)
675        input_qspec_map[input_act] = quantization_config.input_activation
676
677    # do not annotate when the stacked output is an int64 tensor (cf. annotate_unbind)
678    node_tensor = node.meta.get("val")
679    if torch.is_tensor(node_tensor) and node_tensor.dtype == torch.int64:
680        return
681    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
682        input_qspec_map=input_qspec_map,
683        output_qspec=quantization_config.output_activation,
684        _annotated=True,
685    )
686
687
688@register_annotator([torch.ops.aten.matmul.default])
689def annotate_matmul(node: Node, quantization_config: QuantizationConfig) -> None:
690    if _is_annotated([node]):
691        return
692
693    input_act_qspec = quantization_config.input_activation
694    output_act_qspec = quantization_config.output_activation
695
696    input_qspec_map = {}
697    input_act0 = node.args[0]
698    if isinstance(input_act0, Node):
699        input_qspec_map[input_act0] = input_act_qspec
700
701    input_act1 = node.args[1]
702    if isinstance(input_act1, Node):
703        # In matmul, a QNN_DATATYPE_SFIXED_POINT_16 Input1 requires a QNN_DATATYPE_UFIXED_POINT_16 Input0 and must be symmetrically quantized.
704        if input_act_qspec.dtype == torch.int32:
705            # we should use int16 for mm / bmm instead of int4
706            input_qspec_map[input_act1] = get_16a16w_qnn_ptq_config().weight
707        else:
708            input_qspec_map[input_act1] = input_act_qspec
709
710    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
711        input_qspec_map=input_qspec_map,
712        output_qspec=output_act_qspec,
713        _annotated=True,
714    )
715
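# Note on the dtype check above (an assumption about this backend's qconfig, stated for
# clarity rather than taken from this file): the 16-bit-activation configs express their
# activation spec with dtype torch.int32 and int16-sized quant bounds, so seeing
# torch.int32 on the input activation is how the annotator detects a 16-bit-activation
# config and promotes the second operand to the 16a16w weight spec. A minimal sketch of
# that selection:
def _example_pick_matmul_input1_qspec(
    input_act_qspec: QuantizationSpec,
) -> QuantizationSpec:
    if input_act_qspec.dtype == torch.int32:
        # 16-bit activations: pair input1 with a 16-bit weight spec, as done above
        return get_16a16w_qnn_ptq_config().weight
    return input_act_qspec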
716
717@register_annotator([torch.ops.aten.bmm.default])
718def annotate_bmm(node: Node, quantization_config: QuantizationConfig) -> None:
719    if _is_annotated([node]):
720        return
721
722    input_act_qspec = quantization_config.input_activation
723    output_act_qspec = quantization_config.output_activation
724
725    input_qspec_map = {}
726    input_act0 = node.args[0]
727    if isinstance(input_act0, Node):
728        input_qspec_map[input_act0] = input_act_qspec
729
730    input_act1 = node.args[1]
731    if isinstance(input_act1, Node):
732        # In bmm, a QNN_DATATYPE_SFIXED_POINT_16 Input1 requires a QNN_DATATYPE_UFIXED_POINT_16 Input0 and must be symmetrically quantized.
733        if input_act_qspec.dtype == torch.int32:
734            # we should use int16 for mm / bmm instead of int4
735            input_qspec_map[input_act1] = get_16a16w_qnn_ptq_config().weight
736        else:
737            input_qspec_map[input_act1] = input_act_qspec
738
739    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
740        input_qspec_map=input_qspec_map,
741        output_qspec=output_act_qspec,
742        _annotated=True,
743    )
744
745    # A later pass groups nodes with get_source_partitions; decomposed MultiheadAttention shares one source, so rewrite source_fn_stack to keep bmm distinguishable.
746    node.meta["source_fn_stack"] = [(node, torch.bmm)]
747
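# The source_fn_stack rewrite above feeds source-based graph partitioning. A sketch of
# the kind of consumer it targets (hypothetical driver code; the actual pass lives
# elsewhere in this backend):
#
#   from torch.fx.passes.utils.source_matcher_utils import get_source_partitions
#   partitions = get_source_partitions(graph_module.graph, [torch.bmm])
#   # after the rewrite, bmm nodes inside a decomposed MultiheadAttention show up under
#   # the torch.bmm partition instead of being lumped under the attention module source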
748
749@register_annotator(
750    [
751        torch.ops.aten.conv2d.default,
752        torch.ops.aten.conv1d.default,
753        torch.ops.aten.conv_transpose2d.input,
754    ]
755)
756def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
757    if _is_annotated([node]):
758        return
759
760    input_qspec_map = {}
761    input_act = node.args[0]
762    assert isinstance(input_act, Node)
763    input_spec = quantization_config.input_activation
764    input_qspec_map[input_act] = input_spec
765
766    weight = node.args[1]
767    assert isinstance(weight, Node)
768    input_qspec_map[weight] = quantization_config.weight
769
770    if len(node.args) > 2:
771        bias = node.args[2]
772        if isinstance(bias, Node):
773            if callable(quantization_config.bias):
774                input_qspec_map[bias] = quantization_config.bias(node)
775            else:
776                input_qspec_map[bias] = quantization_config.bias
777
778    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
779        input_qspec_map=input_qspec_map,
780        output_qspec=quantization_config.output_activation,
781        _annotated=True,
782    )
783
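# When quantization_config.bias is callable (handled above), the bias spec is derived per
# node instead of being a fixed QuantizationSpec. A common derivation, shown here only as
# a sketch and not necessarily the exact rule used by this backend's qconfig, quantizes
# the bias to int32 with scale = input_scale * weight_scale so it folds directly into the
# convolution's integer accumulator:
def _example_derived_bias_quantization() -> None:
    act_scale, weight_scale, bias_value = 0.05, 0.01, 0.75   # made-up numbers
    bias_scale = act_scale * weight_scale                    # hypothetical derived scale
    q_bias = round(bias_value / bias_scale)                  # 1500; int32 in practice
    assert abs(q_bias * bias_scale - bias_value) < bias_scale / 2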
784
785@register_annotator([torch.ops.aten.linear.default])
786def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None:
787    act_node = node.args[0]
788    weight_node = node.args[1]
789    bias_node = None
790    if len(node.args) > 2:
791        bias_node = node.args[2]
792
793    if _is_annotated([node]):
794        return
795
796    _annotate_input_qspec_map(
797        node,
798        act_node,
799        quantization_config.input_activation,
800    )
801    _annotate_input_qspec_map(
802        node,
803        weight_node,
804        quantization_config.weight,
805    )
806    nodes_to_mark_annotated = [node, weight_node]
807    if bias_node:
808        if callable(quantization_config.bias):
809            bias_config = quantization_config.bias(node)
810        else:
811            bias_config = quantization_config.bias
812        _annotate_input_qspec_map(node, bias_node, bias_config)
813        nodes_to_mark_annotated.append(bias_node)
814    _annotate_output_qspec(node, quantization_config.output_activation)
815    _mark_nodes_as_annotated(nodes_to_mark_annotated)
816
817    # A later pass groups nodes with get_source_partitions; decomposed MultiheadAttention shares one source, so rewrite source_fn_stack to keep these linear layers distinguishable.
818    node.meta["source_fn_stack"] = [(node, torch.nn.Linear)]
819
820
821@register_annotator([torch.ops.aten._native_batch_norm_legit_no_training.default])
822def annotate_batch_norm(node: Node, quantization_config: QuantizationConfig) -> None:
823    act, weight, bias = node.args[0:3]
824    if _is_annotated([node]):
825        return
826
827    _annotate_input_qspec_map(
828        node,
829        act,
830        quantization_config.input_activation,
831    )
832    # QNN requires uint8 instead of int8 in 'weight' config
833    _annotate_input_qspec_map(
834        node,
835        weight,
836        quantization_config.input_activation,
837    )
838    _annotate_input_qspec_map(
839        node,
840        bias,
841        quantization_config.bias,
842    )
843    _annotate_output_qspec(node, quantization_config.output_activation)
844    _mark_nodes_as_annotated([node, *node.args[0:3]])
845
846
847@register_annotator([operator.getitem])
848def annotate_getitem(node: Node, quantization_config: QuantizationConfig) -> None:
849    if _is_annotated([node]):
850        return
851
852    if _is_float_tensor(node):
853        _annotate_output_qspec(node, quantization_config.output_activation)
854        _mark_nodes_as_annotated([node])
855
856
857@register_annotator([torch.ops.aten.layer_norm.default])
858def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) -> None:
859    act_node = node.args[0]
860    weight_node = node.args[2]
861    bias_node = None
862    if len(node.args) > 3:
863        bias_node = node.args[3]
864
865    if _is_annotated([node]):
866        return
867    input_act_qspec = quantization_config.input_activation
868
869    _annotate_input_qspec_map(
870        node,
871        act_node,
872        input_act_qspec,
873    )
874    if input_act_qspec.dtype == torch.int32:
875        _annotate_input_qspec_map(
876            node,
877            weight_node,
878            get_16a16w_qnn_ptq_config().weight,
879        )
880    else:
881        _annotate_input_qspec_map(
882            node,
883            weight_node,
884            input_act_qspec,
885        )
886    nodes_to_mark_annotated = [node, weight_node]
887    if bias_node:
888        _annotate_input_qspec_map(
889            node,
890            bias_node,
891            quantization_config.bias,
892        )
893        nodes_to_mark_annotated.append(bias_node)
894    _annotate_output_qspec(node, quantization_config.output_activation)
895    _mark_nodes_as_annotated(nodes_to_mark_annotated)
896
897
898@register_annotator([torch.ops.aten.cat.default, torch.ops.aten.concat.default])
899def annotate_cat(node: Node, quantization_config: QuantizationConfig) -> None:
900    input_nodes = node.args[0]
901    if _is_annotated([node]):
902        return
903
904    assert isinstance(input_nodes, Sequence)
905
906    first_input_node = input_nodes[0]
907    input_qspec_map = {}
908    assert isinstance(first_input_node, Node)
909    assert isinstance(node, Node)
910    input_qspec_map[first_input_node] = quantization_config.input_activation
911    share_qparams_with_input_act0_qspec = SharedQuantizationSpec(
912        (first_input_node, node)
913    )
914
915    for input_node in input_nodes[1:]:
916        if input_node not in input_qspec_map:
917            assert isinstance(input_node, Node)
918            input_qspec_map[input_node] = share_qparams_with_input_act0_qspec
919
920    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
921        input_qspec_map=input_qspec_map,
922        output_qspec=share_qparams_with_input_act0_qspec,
923        _annotated=True,
924    )
925
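# Sketch of the resulting annotation for a hypothetical cat([a, b, c]) node:
#
#   input_qspec_map = {
#       a: quantization_config.input_activation,
#       b: SharedQuantizationSpec((a, cat_node)),
#       c: SharedQuantizationSpec((a, cat_node)),
#   }
#   output_qspec = SharedQuantizationSpec((a, cat_node))
#
# i.e. every input after the first, and the concatenated output, reuse input a's scale
# and zero point so all slices of the result live on a single quantization grid.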
926
927@register_annotator([torch.ops.aten.unbind.int])
928def annotate_unbind(node: Node, quantization_config: QuantizationConfig) -> None:
929    if _is_annotated([node]):
930        return
931
932    input_qspec_map = {}
933    input_act = node.args[0]
934    assert isinstance(input_act, Node)
935    input_qspec_map[input_act] = quantization_config.input_activation
936
937    node_tensor = node.meta.get("val")
938    if torch.is_tensor(node_tensor) and node_tensor.dtype == torch.int64:
939        return
940
941    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
942        input_qspec_map=input_qspec_map,
943        _annotated=True,
944    )
945
946
947@register_annotator([torch.ops.aten.split.Tensor, torch.ops.aten.chunk.default])
948def annotate_chunk(node: Node, quantization_config: QuantizationConfig) -> None:
949    if _is_annotated([node]):
950        return
951
952    input_qspec_map = {}
953    input_act = node.args[0]
954    assert isinstance(input_act, Node)
955    input_qspec_map[input_act] = quantization_config.input_activation
956
957    node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
958        input_qspec_map=input_qspec_map,
959        _annotated=True,
960    )
961
962    for user in node.users:
963        user.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
964            output_qspec=quantization_config.output_activation,
965            _annotated=True,
966        )
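# Sketch for the multi-output annotation above, assuming a hypothetical
# `left, right = torch.chunk(x, 2)` captured in the graph: the split/chunk node only
# carries the input spec, while each consumer (typically an operator.getitem node) gets
# its own output observer.
#
#   chunk_node.meta[QUANT_ANNOTATION_KEY]  -> input_qspec_map={x: input_activation}
#   getitem_0.meta[QUANT_ANNOTATION_KEY]   -> output_qspec=output_activation
#   getitem_1.meta[QUANT_ANNOTATION_KEY]   -> output_qspec=output_activation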
967