1# mypy: allow-untyped-defs 2import logging 3from typing import Optional 4 5import torch 6from torch._export.error import InternalError 7from torch.ao.quantization.pt2e.utils import ( 8 _filter_sym_size_users, 9 _find_q_dq_node_for_user, 10 _is_valid_annotation, 11) 12from torch.ao.quantization.quantizer import QuantizationSpecBase 13from torch.fx.passes.infra.pass_base import PassBase, PassResult 14 15 16logger = logging.getLogger(__name__) 17logger.setLevel(logging.ERROR) 18 19__all__ = ["PortNodeMetaForQDQ"] 20 21_METADATA_TO_PORT = [ 22 "stack_trace", 23 "quantization_tag", 24] 25 26_QUANTIZE_OPS = [ 27 torch.ops.quantized_decomposed.quantize_per_tensor.default, 28 torch.ops.quantized_decomposed.quantize_per_tensor.tensor, 29 torch.ops.quantized_decomposed.quantize_per_channel.default, 30] 31 32_DEQUANTIZE_OPS = [ 33 torch.ops.quantized_decomposed.dequantize_per_tensor.default, 34 torch.ops.quantized_decomposed.dequantize_per_tensor.tensor, 35 torch.ops.quantized_decomposed.dequantize_per_channel.default, 36] 37 38_CHOOSE_QPARAMS_OPS = [ 39 torch.ops.quantized_decomposed.choose_qparams.tensor, 40 torch.ops.quantized_decomposed.choose_qparams_symmetric.tensor, 41] 42 43 44def _add_metadata(to_node: torch.fx.Node, from_node: torch.fx.Node) -> None: 45 from_meta = from_node.meta 46 for meta_name in _METADATA_TO_PORT: 47 if meta_name in from_meta: 48 to_node.meta[meta_name] = from_meta[meta_name] 49 50 51def _has_quant_annotation(node: torch.fx.Node) -> bool: 52 return "quantization_annotation" in node.meta 53 54 55def _find_choose_qparams_node(node: torch.fx.Node) -> Optional[torch.fx.Node]: 56 # BFS to look for choose qparams 57 from collections import deque 58 59 queue = deque(list(node.users.keys())) 60 while len(queue): 61 n = queue.popleft() 62 if n.op == "output": 63 continue 64 if n.op == "call_function" and n.target in _CHOOSE_QPARAMS_OPS: 65 return n 66 for k in n.users.keys(): 67 queue.append(k) 68 return None 69 70 71def _port_metadata_for_input_quant_nodes( 72 input_node: torch.fx.Node, 73 node: torch.fx.Node, 74 qspec: Optional[QuantizationSpecBase], 75): 76 if qspec is None: 77 return 78 79 is_dynamic_quant = getattr(qspec, "is_dynamic", None) 80 if is_dynamic_quant is not None and is_dynamic_quant is True: 81 choose_qparams_node = _find_choose_qparams_node(input_node) 82 if choose_qparams_node is None: 83 raise ValueError(f"No chose qparams node found for {node}") 84 choose_qparam_users = _filter_sym_size_users(choose_qparams_node) 85 if len(choose_qparam_users) != 2: 86 raise InternalError(f"Expecting exactly two user for {choose_qparams_node}") 87 scale_node = choose_qparam_users.pop() 88 dynamic_q_node = next(iter(scale_node.users.keys())) 89 dynamic_q_node_users = _filter_sym_size_users(dynamic_q_node) 90 if len(dynamic_q_node_users) > 1: 91 raise InternalError(f"Expecting single user for {dynamic_q_node}") 92 dynamic_dq_node = dynamic_q_node_users.pop() 93 _add_metadata(choose_qparams_node, node) 94 _add_metadata(dynamic_q_node, node) 95 _add_metadata(dynamic_dq_node, node) 96 else: 97 q_node, dq_node = _find_q_dq_node_for_user(input_node, node) 98 if q_node is None or dq_node is None: 99 return 100 # add metadata for all the node between q_node and get_attr node 101 # if the q_node can be traced back to get_attr node 102 q_to_get_attr_nodes = [q_node] 103 q_node_input = q_node.args[0] 104 while ( 105 isinstance(q_node_input, torch.fx.Node) 106 and q_node_input.op == "call_function" 107 and q_node_input.target 108 in [ 109 torch.ops.aten.flatten.using_ints, 110 torch.ops.aten.permute.default, 111 torch.ops.aten.permute_copy.default, 112 torch.ops.aten.slice_copy.Tensor, 113 torch.ops.aten.squeeze.dim, 114 torch.ops.aten.squeeze_copy.dim, 115 torch.ops.aten.transpose.Dimname, 116 torch.ops.aten.transpose.int, 117 torch.ops.aten.transpose_, 118 torch.ops.aten.view_copy.default, 119 torch.ops.aten.view.default, 120 torch.ops.aten._mkldnn_transpose, 121 ] 122 ): 123 q_to_get_attr_nodes.append(q_node_input) 124 q_node_input = q_node_input.args[0] 125 if isinstance(q_node_input, torch.fx.Node) and q_node_input.op == "get_attr": 126 for n in q_to_get_attr_nodes: 127 _add_metadata(n, q_node_input) 128 _add_metadata(dq_node, node) 129 130 131def _port_metadata_for_output_quant_nodes( 132 node: torch.fx.Node, qspec: Optional[QuantizationSpecBase] 133): 134 if qspec is None: 135 return 136 137 node_users = _filter_sym_size_users(node) 138 if len(node.users) == 0: 139 return 140 if len(node_users) != 1: 141 logger.warning(f"Expecting {node} to have single user") # noqa: G004 142 q_node = node_users.pop() 143 if q_node.op != "call_function" or q_node.target not in _QUANTIZE_OPS: 144 logger.warning( 145 f"Expecting {node} user to be a quantized op but got {q_node}" # noqa: G004 146 ) # noqa: G004 147 return 148 149 _add_metadata(q_node, node) 150 151 152class PortNodeMetaForQDQ(PassBase): 153 """ 154 Port metadata for nodes added by quantization flow. 155 For static quant these are: 156 - quantizer_per_tensor.default, dequantize_per_tensor.default 157 - quantizer_per_channel.default, dequantize_per_channel.default 158 For dynamic quant these are: 159 - choose_qparams.tensor 160 - quantizer_per_tensor.tensor, dequantize_per_tensor.tensor 161 - quantizer_per_channel.default, dequantize_per_channel.default 162 163 Rules of porting metadata: 164 - Metadata to be ported: 165 - nn_module_stack 166 - stack_trace 167 - quantization_tag 168 - Metadata to NOT be ported: 169 - Everything else 170 - Rules: 171 - Statically quantized patterns: 172 - Dequantize nodes on the inputs to be quantized inherit metadata of the consumer node. 173 - Quantize nodes on the outputs inherit metadata of the producer node. 174 - Example 1: 175 - Original: [Conv -> AvgPool -> Linear] 176 - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ] 177 - Inner brackets specify which nodes Q/DQ inherit metdata from 178 - [Q-> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> [DQ -> Linear -> Q] -> DQ] 179 - Note first Q and last DQ do not inherit metadata from any nodes 180 - Example 2: 181 - Original: [Conv -> AvgPool -> Linear] 182 - AvgPool is not quantized 183 - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> Linear -> Q -> DQ] 184 - Inner brackets specify which nodes Q/DQ inherit metdata from 185 - [Q-> [DQ -> Conv -> Q] -> DQ -> [AvgPool] -> Q -> [DQ -> Linear -> Q] -> DQ] 186 - Note DQ and Q nodes around AvgPool do not inherit metadata from AvgPool because 187 AvgPool was not supposed to be quantized. Metadata porting relies on quantization_annotation 188 on the nodes (in this case AvgPool node) to conclude if the node or patter was 189 supposed to be quantized. And subsequntly decide if the preceding Q, if any, should 190 inherit metadata from AvgPool. 191 - Dynamically quantized patterns: 192 - Input that are dynamically quantized have choose_qparams, quantize and dequantize nodes 193 - For example, below linear is dynamically quantized while rest statically: 194 - Original: [Conv -> AvgPool -> Linear] 195 - Quantized [Q-> DQ -> Conv -> Q -> DQ -> AvgPool -> Q -> DQ -> choose_params -> Q -> DQ -> Linear] 196 - Quantized [Q-> [DQ -> Conv -> Q] -> [DQ -> AvgPool -> Q] -> DQ -> [choose_params -> Q -> DQ -> Linear]] 197 - Note first Q does not inherit metadata from any nodes 198 NB: 199 - The best place for porting metadata is during observer conversion to q/dq. This is because it precisely 200 knows which quantization spec is converted to q/dq and thus from where the metadata should be ported. 201 However, since FX and PT2E quant workflow are on a common code-base, this hurts readability quite a bit. 202 Doing it via a separate pass, helps readability of the code. Once we are able to refactor PT2E quant 203 code, this pass should like to be integrated in the refactored variant of "convert" step. 204 """ 205 206 def call(self, graph_module: torch.fx.GraphModule) -> PassResult: 207 for node in graph_module.graph.nodes: 208 annotation = node.meta.get("quantization_annotation", None) 209 if _is_valid_annotation(annotation): 210 input_qspec_map = node.meta["quantization_annotation"].input_qspec_map 211 output_qspec = node.meta["quantization_annotation"].output_qspec 212 for input_node, qspec in input_qspec_map.items(): 213 _port_metadata_for_input_quant_nodes(input_node, node, qspec) 214 _port_metadata_for_output_quant_nodes(node, output_qspec) 215 return PassResult(graph_module, True) 216