# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

#
# Utility functions for ArmQuantizer
#

import operator
from typing import Callable, cast, List

import torch
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from torch._subclasses import FakeTensor

from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    SharedQuantizationSpec,
)
from torch.fx import GraphModule, Node

def is_annotated(node: Node) -> bool:
    """Given a node, return whether the node is annotated."""
    return (
        "quantization_annotation" in node.meta
        and cast(
            QuantizationAnnotation, node.meta["quantization_annotation"]
        )._annotated
    )


def are_annotated(nodes: List[Node]) -> bool:
    """Given a list of nodes (representing an operator pattern),
    return True if any of the nodes is annotated, otherwise return False.
    """
    for node in nodes:
        if is_annotated(node):
            return True
    return False


def mark_nodes_as_annotated(nodes: List[Node]) -> None:
    """Marks all nodes in list 'nodes' as annotated. If needed, an empty
    QuantizationAnnotation is added to the quantization_annotation node meta entry.
    """
    for node in nodes:
        if node is not None:
            if "quantization_annotation" not in node.meta:
                node.meta["quantization_annotation"] = QuantizationAnnotation()
            node.meta["quantization_annotation"]._annotated = True


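# A minimal sketch of how the three helpers above are typically combined when
# annotating an operator pattern. Assumes a traced GraphModule; the helper name
# and the relu filter are illustrative, not part of this module's API.
def _example_annotate_relu_pattern(gm: GraphModule) -> None:
    relu_nodes = [
        n for n in gm.graph.nodes if n.target == torch.ops.aten.relu.default
    ]
    # Skip patterns that another annotator has already claimed.
    if relu_nodes and not are_annotated(relu_nodes):
        mark_nodes_as_annotated(relu_nodes)
        assert all(is_annotated(n) for n in relu_nodes)

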
def get_shared_qspec(
    node: Node, gm: GraphModule, quantization_config: QuantizationConfig
):
    """Returns a quantization constellation with a SharedQuantizationSpec for the inputs
    and output of 'node'.
    Parameters:
        node: a node with two inputs that should share quantization parameters.
        gm: The GraphModule containing the node. Used to inspect global graph features.
        quantization_config: a QuantizationConfig with the input QuantizationSpec to share.
    Returns:
        input_qspec_map: a dict[Node, QuantizationSpec] that maps the inputs of 'node' to
            the correct QuantizationSpec.
        shared_with_input0_qspec: The SharedQuantizationSpec to be used as the output QuantizationSpec.

        Both outputs are None if one of the inputs is a node that can't be quantized.
    """
    input_act0 = cast(Node, node.args[0])
    input_act1 = node.args[1]

    input_act_qspec = quantization_config.get_input_act_qspec()
    shared_with_input0_qspec = SharedQuantizationSpec((input_act0, node))

    input_qspec_map = {}
    if isinstance(input_act0, Node):
        if not is_input_ok_for_quantization(input_act0, gm):
            return None, None
        input_qspec_map[input_act0] = input_act_qspec

    if isinstance(input_act1, Node):
        if not is_input_ok_for_quantization(input_act1, gm):
            return None, None
        if input_act0 is not input_act1:
            input_qspec_map[input_act1] = shared_with_input0_qspec
    return input_qspec_map, shared_with_input0_qspec


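# A minimal sketch of how get_shared_qspec is typically consumed when annotating
# a two-input op such as aten.add. Assumes 'node' is an unannotated node from a
# traced GraphModule; the helper name is illustrative only.
def _example_annotate_two_input_op(
    node: Node, gm: GraphModule, quantization_config: QuantizationConfig
) -> None:
    input_qspec_map, output_qspec = get_shared_qspec(node, gm, quantization_config)
    if input_qspec_map is None:
        return  # one of the inputs can't be quantized
    node.meta["quantization_annotation"] = QuantizationAnnotation(
        input_qspec_map=input_qspec_map,
        output_qspec=output_qspec,
        _annotated=True,
    )

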
def is_input_ok_for_quantization(input_act: Node, gm: GraphModule):
    """Check if an input can be quantized. The input cannot be quantized if:
    - The node does not output a float tensor, or
    - The node outputs a large scalar.
    """
    return not (
        is_input_non_float_tensor(input_act) or is_input_large_scalar(input_act, gm)
    )


def get_node_target(module: torch.nn.Module | GraphModule, target_str: str):
    """Resolve a dotted attribute path (e.g. a get_attr target) relative to 'module'."""
    targets = target_str.split(".")
    for target in targets[:-1]:
        module = module.get_submodule(target)
    return getattr(module, targets[-1])


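# A small sketch of what get_node_target resolves, assuming the graph contains a
# get_attr node whose target is a dotted path such as "conv.weight" (the path is
# illustrative only):
#
#   weight = get_node_target(gm, "conv.weight")
#   # equivalent to gm.get_submodule("conv").weight

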
def is_input_large_scalar(node: Node, gm: GraphModule):
    """Check if an input is a large scalar value, so that we can skip quantization
    for the node, since the histc op (used by HistogramObserver) only works for
    values up to a certain upper bound.
    """
    if node.op == "get_attr" and isinstance(node.target, str):
        tensor = get_node_target(gm, node.target)
        # torch.histc works until this upper bound
        HISTC_UPPER_BOUND = 3.4028235e15
        return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND
    return False


def is_input_non_float_tensor(node: Node) -> bool:
    """Check if the input is not a float tensor, so that we can skip quantization
    for the node, since observers only work with float tensors.
    """
    if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor):
        return True
    return node.meta["val"].dtype != torch.float32


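# A minimal sketch of the input checks above applied to a node's arguments,
# keeping only those that are safe to quantize. Assumes a traced GraphModule;
# the helper name is illustrative only.
def _example_quantizable_inputs(node: Node, gm: GraphModule) -> List[Node]:
    return [
        arg
        for arg in node.args
        if isinstance(arg, Node) and is_input_ok_for_quantization(arg, gm)
    ]

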
def is_share_obs_or_fq_op(op: Callable) -> bool:
    """Returns whether the operation 'op' can be quantized using a shared observer or
    fake quantizer. This means that the operation can inherit its quantization spec
    from parent nodes.
    """
    return op in [
        torch.ops.aten.hardtanh.default,
        torch.ops.aten.hardtanh_.default,
        torch.ops.aten.relu.default,
        torch.ops.aten.mean.default,
        torch.ops.aten.mean.dim,
        torch.ops.aten.permute.default,
        torch.ops.aten.permute_copy.default,
        # TODO: remove?
        torch.ops.aten.adaptive_avg_pool2d.default,
        torch.ops.aten.avg_pool2d.default,
        torch.ops.aten.max_pool2d.default,
        torch.ops.aten.full.default,
        torch.ops.aten.flatten.using_ints,
        torch.ops.aten.dropout.default,
        operator.getitem,
    ]


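# A small sketch of how is_share_obs_or_fq_op gates qspec sharing, mirroring the
# check performed in propagate_annotation below:
#
#   if n.op == "call_function" and is_share_obs_or_fq_op(cast(Callable, n.target)):
#       ...  # n may inherit its quantization spec from its parent node

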
def propagate_annotation(model: GraphModule) -> None:
    """For unannotated ops that can share an observer or fake quantizer,
    annotate with a SharedQuantizationSpec, where the shared spec is the
    output spec of the parent node.
    This propagates output qspecs downward in the graph until
    an op that is already annotated or can't share a qspec is encountered.
    """
    for n in model.graph.nodes:
        n = cast(Node, n)
        if is_annotated(n):
            continue
        if n.op != "call_function" or not is_share_obs_or_fq_op(
            cast(Callable, n.target)
        ):
            continue

        prev_node = n.args[0]
        if not isinstance(prev_node, Node):
            continue

        quantization_annotation = cast(
            QuantizationAnnotation | None,
            prev_node.meta.get("quantization_annotation", None),
        )
        if not quantization_annotation or not quantization_annotation.output_qspec:
            continue

        # propagate the previous output_qspec to the current node
        shared_qspec = SharedQuantizationSpec(prev_node)
        n.meta["quantization_annotation"] = QuantizationAnnotation(
            input_qspec_map={
                prev_node: shared_qspec,
            },
            output_qspec=shared_qspec,
            _annotated=True,
        )

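
# A minimal end-to-end sketch, assuming 'model' is a GraphModule that the
# ArmQuantizer's per-op annotators have already processed (the variable name is
# illustrative only):
#
#   propagate_annotation(model)
#   # shareable ops (see is_share_obs_or_fq_op) now reuse their parent's output qspec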