xref: /aosp_15_r20/external/executorch/backends/example/example_operators/utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7from torch.ao.quantization.quantizer.quantizer import QuantizationAnnotation
8
9
def _nodes_are_annotated(node_list):
    """Return True when every node in ``node_list`` is marked as annotated.

    A node counts as annotated when ``node.meta["quantization_annotation"]``
    exists, is truthy, and has a truthy ``_annotated`` attribute. An empty
    ``node_list`` yields True.
    """

    def _is_marked(candidate):
        # A missing (or otherwise falsy) annotation means "not annotated".
        annotation = candidate.meta.get("quantization_annotation", None)
        return bool(annotation) and bool(annotation._annotated)

    # all() stops at the first unannotated node, like the early return would.
    return all(_is_marked(candidate) for candidate in node_list)
20
21
22def _annotate_nodes(node_tuples, quant_spec, input_node=False):
23    for node_tuple in node_tuples:
24        node = node_tuple[0]
25        quant_annotation = node.meta.get(
26            "quantization_annotation", QuantizationAnnotation(_annotated=True)
27        )
28        if input_node:
29            input_node = node_tuple[1]
30            quant_annotation.input_qspec_map[input_node] = quant_spec
31        else:
32            quant_annotation.output_qspec = quant_spec
33        node.meta["quantization_annotation"] = quant_annotation
34