# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import itertools
from collections import OrderedDict
from math import frexp, isclose, trunc
from typing import Any, Dict, List, Tuple, Type

import torch
from torch import fx
from torch._ops import OpOverload
from torch.ao.quantization import ObserverOrFakeQuantize

from torch.fx import GraphModule
from torch.fx.passes.utils.source_matcher_utils import (
    check_subgraphs_connected,
    SourcePartition,
)


def quantize_tensor_multiplier(
    requantize_scale_tensor: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
29    """
30    Given requantize_scale_tensor with values in the interval (0, 1),
31    produce a pair of tensors (out_multiplier, right_shift) where out_multiplier
32    is an int32 tensor representing fixed-point values in the interval [-1, 1),
33    and right_shift is an amount to shift right by, so that the floating-point
34    multiplication of some int32 input with each value of requantize_scale_tensor:
35        result = int32_value * requantize_scale_tensors[i]
36    is best approximated by the integer-arithmetic-only code:
37        result = RoundingRightShift(FixedPointMultiplication(int32_value,
38                                    out_multiplier[i]), right_shift[i])
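
    Example (illustrative; values are approximate, not from the original source):
        For a single scale of 0.0003, frexp(0.0003) gives a significand of
        about 0.6144 and an exponent of -11, so out_multiplier[i] is roughly
        round(0.6144 * 2**31) == 1319413953 and right_shift[i] == -11;
        1319413953 / 2**31 * 2**-11 reconstructs ~0.0003.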
39    """
40
    # This is identical to C++11 std::round(): Python's built-in round()
    # rounds halves to the nearest even value (banker's rounding), whereas
    # C++ rounds halfway cases away from zero.
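    # For example, round_away_zero(2.5) == 3 and round_away_zero(-2.5) == -3,
    # while Python's round(2.5) == 2 and round(-2.5) == -2.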
    def round_away_zero(f: float) -> int:
        r = -0.5 if (f < 0) else 0.5
        return trunc(f + r)

    def quantize_scalar_multiplier(requantize_scale: float) -> Tuple[int, int]:
        significand, exponent = frexp(requantize_scale)
        significand_q31 = int(round_away_zero(significand * (1 << 31)))
        # Handle the special case when the real multiplier was so close to 1
        # that its fixed-point approximation was indistinguishable from 1.
        # We handle this by dividing it by two and incrementing the exponent
        # (the right-shift amount) by 1.
        if significand_q31 == (1 << 31):
            significand_q31 //= 2
            exponent += 1

        # Verify that the decomposition of requantize_scale into significand
        # and exponent is correct.
        reconstructed = significand_q31 / (1 << 31) * pow(2, exponent)
        assert isclose(
            requantize_scale, reconstructed, rel_tol=1e-4, abs_tol=1e-4
        ), "computation of significand and exponent from requantize_scale is not accurate"

        return (significand_q31, exponent)

    # Flatten the input scale tensor so that we can operate on individual values
    orig_shape = requantize_scale_tensor.shape
    flattened_tensor = requantize_scale_tensor.flatten().to(torch.float32)
    out_multiplier = torch.zeros(flattened_tensor.shape, dtype=torch.int32)
    right_shift = torch.zeros(flattened_tensor.shape, dtype=torch.int32)

    # Iterate over the flattened scale tensor and decompose each value into a
    # significand (out_multiplier) and an exponent (right_shift).
    for idx, scale in enumerate(flattened_tensor):
        (si, ex) = quantize_scalar_multiplier(scale)
        out_multiplier[idx], right_shift[idx] = si, ex

    # Reshape the tensors back to the original shape
    out_multiplier = out_multiplier.reshape(orig_shape)
    right_shift = right_shift.reshape(orig_shape)

    return (out_multiplier, right_shift)


def is_annotated(nodes: List[fx.Node]) -> bool:
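    """
    Returns True if any node in the list already carries a quantization
    annotation that has been marked as annotated.
    """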
    annotated = False
    for node in nodes:
        annotated = annotated or (
            "quantization_annotation" in node.meta
            and node.meta["quantization_annotation"]._annotated
        )
    return annotated


def no_outside_users(fused_partition) -> bool:
    """
    Checks that every partition other than the last has no users outside the
    fused partition.
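    E.g., in a hypothetical (conv, relu) fusion candidate, the conv partition
    must expose a single output node with exactly one user.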
100    """
101    for source_partition in fused_partition[:-1]:
102        if len(source_partition.output_nodes) != 1:
103            return False
104        if len(source_partition.output_nodes[0].users) != 1:
105            return False
106    return True
107
108
109def create_zero_bias_int32(
110    graph_module: GraphModule,
111    weight_node: fx.Node,
112    bias_scale: float,
113) -> fx.Node:
114    """
    Creates a zero int32 bias tensor whose length is weight.shape[0].
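    E.g. (illustrative), for a weight of shape [out_channels, in_channels, ...],
    this emits an aten.full([out_channels], 0.0, dtype=torch.int32) node.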
116    """
117    attr_node = getattr(graph_module, weight_node.target)
118    weight_shape = list(attr_node.shape)
119    bias_shape = weight_shape[0]
120    return graph_module.graph.call_function(
121        torch.ops.aten.full.default,
122        ([bias_shape], 0.0),
123        {"dtype": torch.int32},
124    )
125
126
127def get_bias_qparams(
128    obs_or_fqs: List[ObserverOrFakeQuantize],
129) -> Tuple[torch.Tensor, torch.Tensor]:
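    """
    Derives bias quantization parameters from the activation and weight
    observers/fake-quantizers: bias_scale = act_scale * weight_scale and
    bias_zero_point = 0, the usual convention for quantized conv/linear bias.
    """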
    act_scale, _ = obs_or_fqs[0].calculate_qparams()
    weight_scale, _ = obs_or_fqs[1].calculate_qparams()
    bias_scale = act_scale * weight_scale
    bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int32)
    return bias_scale, bias_zero_point


def get_conv_args(arg, first_val: int) -> List[fx.Node]:
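    # Expand a length-1 conv argument (e.g. stride or padding) to length 2 by
    # prepending first_val; a length-2 argument is returned unchanged.
    # Illustrative: get_conv_args([2], 1) -> [1, 2]; get_conv_args([2, 2], 1) -> [2, 2].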
    return arg if len(arg) == 2 else [first_val, arg[0]]


def get_aten_node_target_partitions(
    graph: torch.fx.Graph,
    wanted_original_aten_op: List[OpOverload],
):
    """
    Args:
        graph: The graph we want to partition
        wanted_original_aten_op: List of original_aten ops (OpOverload)

    Returns:
        Dictionary mapping the source type (the module class or function
        recorded in source_fn_stack) of each matched node to a list of
        SourcePartitions, one partition per node that was decomposed from the
        given aten ops.
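
    Example (illustrative, not from the original source):
        get_aten_node_target_partitions(gm.graph, [torch.ops.aten.addmm.default])
        returns something like {torch.nn.Linear: [SourcePartition(...), ...]},
        with one single-node partition per matched addmm node.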
154    """
155    modules: Dict[Type, Dict[str, List[torch.fx.Node]]] = {}
156
157    for node in graph.nodes:
158        # The metadata source_fn should contain a tuple of a unique name for the
159        # source, and the source function if the node is decomposed from a
160        # function, or the type of module if the node is decomposed from a leaf
161        # module
162        # TODO(matthiascremon): look into ways to avoid using source_fn_stack
163        if (source_fn_st := node.meta.get("source_fn_stack")) is None:
164            continue
165
166        source_fn = source_fn_st[-1]
167        if node.target not in wanted_original_aten_op:
168            continue
169
170        diff_modules = modules.setdefault(source_fn[1], {})
171        partition = diff_modules.setdefault(node.name, [])
172        partition.append(node)
173
174    def make_partition(
175        nodes: List[torch.fx.Node], module_type: Type
176    ) -> SourcePartition:
177        input_nodes = set()
178        output_nodes = set()
179        params = set()
180        for node in nodes:
181            for arg in node.args:
182                if isinstance(arg, torch.fx.Node) and arg not in nodes:
183                    input_nodes.add(arg)
184
185            if node.op == "get_attr":
186                params.add(node)
187
188            for user in node.users.keys():
189                if user not in nodes:
190                    output_nodes.add(node)
191
192        return SourcePartition(
193            nodes,
194            module_type,
195            list(input_nodes),
196            list(output_nodes),
197            list(params),  # type: ignore[arg-type]
198        )
199
200    ret: Dict[Type[Any], List[SourcePartition]] = {}
201
202    for k, v in modules.items():
203        ret[k] = [make_partition(partition, k) for partition in v.values()]
204
205    return ret
206
207
def _partitions_sequential(partitions: Tuple[SourcePartition, ...]) -> bool:
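    """
    Returns True if every adjacent pair of partitions in the sequence is
    connected, as reported by check_subgraphs_connected.
    """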
    prev_partition = None
    for partition in partitions:
        if prev_partition is not None and not check_subgraphs_connected(
            prev_partition, partition
        ):
            return False
        prev_partition = partition
    return True


def find_sequential_partitions_aten(
    gm: torch.fx.GraphModule,
    partition_types: List[Any],
):
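    """
    Finds chains of partitions, one per entry in partition_types, in which each
    partition's output feeds the next.

    Illustrative usage (a minimal sketch, not from the original source):
        find_sequential_partitions_aten(
            gm, [torch.ops.aten.convolution.default, torch.ops.aten.relu.default]
        )
        returns a list of (conv_partition, relu_partition) tuples.
    """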
    typed_partitions: OrderedDict[Any, List[SourcePartition]] = OrderedDict()
    for partition_type in partition_types:
        partitions = get_aten_node_target_partitions(gm.graph, [partition_type])
        typed_partitions[partition_type] = list(
            itertools.chain.from_iterable(partitions.values())
        )

    typed_partitions_list = list(typed_partitions.values())
    fusion_candidates = itertools.product(*typed_partitions_list)
    fused_partitions = []
    for candidate in fusion_candidates:
        if _partitions_sequential(candidate):
            fused_partitions.append(candidate)
    return fused_partitions