xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/pt2e/port_metadata_pass.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import logging
3from typing import Optional
4
5import torch
6from torch._export.error import InternalError
7from torch.ao.quantization.pt2e.utils import (
8    _filter_sym_size_users,
9    _find_q_dq_node_for_user,
10    _is_valid_annotation,
11)
12from torch.ao.quantization.quantizer import QuantizationSpecBase
13from torch.fx.passes.infra.pass_base import PassBase, PassResult
14
15
logger = logging.getLogger(__name__)
# With the level pinned to ERROR, the logger.warning(...) diagnostics emitted
# by this module are dropped unless someone lowers this logger's level.
logger.setLevel(logging.ERROR)

__all__ = ["PortNodeMetaForQDQ"]

# Keys of ``Node.meta`` that this pass copies from a source node onto the
# Q/DQ (and choose_qparams) nodes it visits. All other meta keys are left
# untouched on the destination node.
_METADATA_TO_PORT = [
    "stack_trace",
    "quantization_tag",
]

# quantized_decomposed quantize ops; the user of an annotated node's output is
# expected to be one of these for its metadata to be ported.
_QUANTIZE_OPS = [
    torch.ops.quantized_decomposed.quantize_per_tensor.default,
    torch.ops.quantized_decomposed.quantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.quantize_per_channel.default,
]

# quantized_decomposed dequantize ops. Not referenced in the code visible
# here; presumably kept for symmetry or used elsewhere — TODO confirm.
_DEQUANTIZE_OPS = [
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
]

# Ops that compute quantization parameters at runtime (dynamic quantization);
# searched for downstream of a dynamically quantized input.
_CHOOSE_QPARAMS_OPS = [
    torch.ops.quantized_decomposed.choose_qparams.tensor,
    torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor,
]
42
43
def _add_metadata(to_node: torch.fx.Node, from_node: torch.fx.Node) -> None:
    """Copy the portable metadata entries of ``from_node`` onto ``to_node``.

    Only keys listed in ``_METADATA_TO_PORT`` are considered. Keys missing
    from ``from_node.meta`` are skipped, so any existing value stored under
    that key on ``to_node`` is left as-is.
    """
    source_meta = from_node.meta
    portable = {
        key: source_meta[key] for key in _METADATA_TO_PORT if key in source_meta
    }
    to_node.meta.update(portable)
49
50
51def _has_quant_annotation(node: torch.fx.Node) -> bool:
52    return "quantization_annotation" in node.meta
53
54
55def _find_choose_qparams_node(node: torch.fx.Node) -> Optional[torch.fx.Node]:
56    # BFS to look for choose qparams
57    from collections import deque
58
59    queue = deque(list(node.users.keys()))
60    while len(queue):
61        n = queue.popleft()
62        if n.op == "output":
63            continue
64        if n.op == "call_function" and n.target in _CHOOSE_QPARAMS_OPS:
65            return n
66        for k in n.users.keys():
67            queue.append(k)
68    return None
69
70
# Layout/shape-only ops that may sit between a ``get_attr`` node and the
# quantize node consuming it. ``transpose_`` and ``_mkldnn_transpose`` are
# listed both as OpOverloadPacket (as before) and as their ``.default``
# OpOverload. The membership test compares the node's target against these
# objects, so a packet entry does not match a graph whose call_function target
# is the concrete overload.
_Q_INPUT_PASSTHROUGH_OPS = (
    torch.ops.aten.flatten.using_ints,
    torch.ops.aten.permute.default,
    torch.ops.aten.permute_copy.default,
    torch.ops.aten.slice_copy.Tensor,
    torch.ops.aten.squeeze.dim,
    torch.ops.aten.squeeze_copy.dim,
    torch.ops.aten.transpose.Dimname,
    torch.ops.aten.transpose.int,
    torch.ops.aten.transpose_,
    torch.ops.aten.transpose_.default,
    torch.ops.aten.view_copy.default,
    torch.ops.aten.view.default,
    torch.ops.aten._mkldnn_transpose,
    torch.ops.aten._mkldnn_transpose.default,
)


def _port_metadata_for_dynamic_input_quant_nodes(
    input_node: torch.fx.Node, node: torch.fx.Node
) -> None:
    """Copy ``node``'s metadata onto the choose_qparams -> q -> dq chain for ``input_node``."""
    choose_qparams_node = _find_choose_qparams_node(input_node)
    if choose_qparams_node is None:
        raise ValueError(f"No choose qparams node found for {node}")
    choose_qparam_users = _filter_sym_size_users(choose_qparams_node)
    if len(choose_qparam_users) != 2:
        raise InternalError(
            f"Expecting exactly two users for {choose_qparams_node}"
        )
    # Take one of the two qparam outputs (historically named scale_node) and
    # treat its first user as the dynamic quantize node. This assumes both
    # qparam outputs feed that same quantize node — TODO confirm.
    qparam_node = choose_qparam_users.pop()
    dynamic_q_node = next(iter(qparam_node.users), None)
    if dynamic_q_node is None:
        # Previously StopIteration leaked out of next(); report it explicitly.
        raise InternalError(f"Expecting at least one user for {qparam_node}")
    dynamic_q_node_users = _filter_sym_size_users(dynamic_q_node)
    if len(dynamic_q_node_users) != 1:
        # A zero-user quantize node used to fail with IndexError from pop().
        raise InternalError(f"Expecting single user for {dynamic_q_node}")
    dynamic_dq_node = dynamic_q_node_users.pop()
    for inserted in (choose_qparams_node, dynamic_q_node, dynamic_dq_node):
        _add_metadata(inserted, node)


def _port_metadata_for_static_input_quant_nodes(
    input_node: torch.fx.Node, node: torch.fx.Node
) -> None:
    """Port metadata onto the q/dq pair inserted between ``input_node`` and ``node``."""
    q_node, dq_node = _find_q_dq_node_for_user(input_node, node)
    if q_node is None or dq_node is None:
        return
    # Walk backwards from q_node through layout-only ops. If the walk ends at
    # a get_attr node, q_node and every op on the path inherit the get_attr
    # node's metadata. Otherwise q_node is left untouched here.
    q_to_get_attr_nodes = [q_node]
    q_node_input = q_node.args[0]
    while (
        isinstance(q_node_input, torch.fx.Node)
        and q_node_input.op == "call_function"
        and q_node_input.target in _Q_INPUT_PASSTHROUGH_OPS
    ):
        q_to_get_attr_nodes.append(q_node_input)
        q_node_input = q_node_input.args[0]
    if isinstance(q_node_input, torch.fx.Node) and q_node_input.op == "get_attr":
        for n in q_to_get_attr_nodes:
            _add_metadata(n, q_node_input)
    # The dequantize node always inherits the consumer's metadata.
    _add_metadata(dq_node, node)


def _port_metadata_for_input_quant_nodes(
    input_node: torch.fx.Node,
    node: torch.fx.Node,
    qspec: Optional[QuantizationSpecBase],
) -> None:
    """Port metadata onto the quantization nodes feeding ``input_node`` into ``node``.

    Args:
        input_node: A key of ``node``'s ``input_qspec_map``, i.e. the value
            ``node`` consumes.
        node: The annotated consumer node whose metadata is ported.
        qspec: Quantization spec for this input. ``None`` means nothing to do.
            A spec whose ``is_dynamic`` attribute is exactly ``True`` selects
            the dynamic (choose_qparams -> q -> dq) path. Anything else
            selects the static q/dq path.

    Raises:
        ValueError: Dynamic spec but no choose_qparams node is reachable
            downstream of ``input_node``.
        InternalError: The dynamic choose_qparams -> q -> dq chain does not
            have the expected number of users.
    """
    if qspec is None:
        return

    if getattr(qspec, "is_dynamic", None) is True:
        _port_metadata_for_dynamic_input_quant_nodes(input_node, node)
    else:
        _port_metadata_for_static_input_quant_nodes(input_node, node)
129
130
def _port_metadata_for_output_quant_nodes(
    node: torch.fx.Node, qspec: Optional[QuantizationSpecBase]
):
    """Copy ``node``'s metadata onto the quantize node that consumes its output.

    Args:
        node: The annotated producer node.
        qspec: Its output quantization spec. ``None`` means nothing to do.

    Best effort: if the (non sym_size) user is not one of ``_QUANTIZE_OPS``,
    a warning is logged and nothing is ported. With several users, a warning
    is logged and one of them is still used, as before.
    """
    if qspec is None:
        return

    node_users = _filter_sym_size_users(node)
    # Check the *filtered* users. A node consumed only by sym_size nodes has
    # no quantize user, and popping from the empty list would raise IndexError.
    if not node_users:
        return
    if len(node_users) != 1:
        logger.warning("Expecting %s to have single user", node)
    q_node = node_users.pop()
    if q_node.op != "call_function" or q_node.target not in _QUANTIZE_OPS:
        logger.warning(
            "Expecting %s user to be a quantized op but got %s", node, q_node
        )
        return

    _add_metadata(q_node, node)
150
151
class PortNodeMetaForQDQ(PassBase):
    """
    Port metadata for nodes added by quantization flow.
    For static quant these are:
    - quantize_per_tensor.default, dequantize_per_tensor.default
    - quantize_per_channel.default, dequantize_per_channel.default
    For dynamic quant these are:
    - choose_qparams.tensor / choose_qparams_symmetric.tensor
    - quantize_per_tensor.tensor, dequantize_per_tensor.tensor
    - quantize_per_channel.default, dequantize_per_channel.default

    Rules of porting metadata:
    - Metadata to be ported:
      - stack_trace
      - quantization_tag
    - Metadata to NOT be ported:
      - Everything else (including nn_module_stack)
    - Rules:
      - Statically quantized patterns:
        - Dequantize nodes on the inputs to be quantized inherit metadata of the consumer node.
        - Quantize nodes on the outputs inherit metadata of the producer node.
        - Example 1:
          - Original: [Conv -> AvgPool -> Linear]
          - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ]
          - Inner brackets specify which nodes Q/DQ inherit metadata from
          - [Q-> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> [DQ -> Linear -> Q] -> DQ]
          - Note first Q and last DQ do not inherit metadata from any nodes
        - Example 2:
          - Original: [Conv -> AvgPool -> Linear]
          - AvgPool is not quantized
          - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ]
          - Inner brackets specify which nodes Q/DQ inherit metadata from
          - [Q-> [DQ -> Conv -> Q] -> DQ -> [AvgPool] -> Q -> [DQ -> Linear -> Q] -> DQ]
          - Note DQ and Q nodes around AvgPool do not inherit metadata from AvgPool because
            AvgPool was not supposed to be quantized. Metadata porting relies on quantization_annotation
            on the nodes (in this case AvgPool node) to conclude if the node or pattern was
            supposed to be quantized. And subsequently decide if the preceding Q, if any, should
            inherit metadata from AvgPool.
      - Dynamically quantized patterns:
        - Inputs that are dynamically quantized have choose_qparams, quantize and dequantize nodes
        - For example, below linear is dynamically quantized while rest statically:
          - Original: [Conv -> AvgPool -> Linear]
          - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> choose_qparams -> Q -> DQ -> Linear]
          - Quantized [Q-> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> DQ -> [choose_qparams -> Q -> DQ -> Linear]]
          - Note first Q does not inherit metadata from any nodes
    NB:
    - The best place for porting metadata is during observer conversion to q/dq. This is because it precisely
      knows which quantization spec is converted to q/dq and thus from where the metadata should be ported.
      However, since FX and PT2E quant workflow are on a common code-base, this hurts readability quite a bit.
      Doing it via a separate pass helps readability of the code. Once we are able to refactor PT2E quant
      code, this pass should likely be integrated in the refactored variant of "convert" step.
    """

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """Port metadata around every node with a valid quantization annotation.

        For each such node, the quantization nodes on each annotated input are
        handled, then the quantize node on its output. The result always
        reports the graph module as modified.
        """
        for node in graph_module.graph.nodes:
            annotation = node.meta.get("quantization_annotation", None)
            if not _is_valid_annotation(annotation):
                continue
            # `annotation` is the object stored in node.meta; no second lookup.
            input_qspec_map = annotation.input_qspec_map
            output_qspec = annotation.output_qspec
            for input_node, qspec in input_qspec_map.items():
                _port_metadata_for_input_quant_nodes(input_node, node, qspec)
            _port_metadata_for_output_quant_nodes(node, output_qspec)
        return PassResult(graph_module, True)
216