# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

#
# Utility functions for ArmQuantizer
#

import operator
from typing import Callable, cast, List

import torch
from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig
from torch._subclasses import FakeTensor

from torch.ao.quantization.quantizer import (
    QuantizationAnnotation,
    SharedQuantizationSpec,
)
from torch.fx import GraphModule, Node


def is_annotated(node: Node) -> bool:
    """Given a node, return whether it is annotated."""
    return (
        "quantization_annotation" in node.meta
        and cast(
            QuantizationAnnotation, node.meta["quantization_annotation"]
        )._annotated
    )


def are_annotated(nodes: List[Node]) -> bool:
    """Given a list of nodes (representing an operator pattern),
    return True if any of the nodes is annotated, otherwise return False.
    """
    for node in nodes:
        if is_annotated(node):
            return True
    return False


def mark_nodes_as_annotated(nodes: List[Node]) -> None:
    """Marks all nodes in the list 'nodes' as annotated. If needed, an empty
    QuantizationAnnotation is added to the 'quantization_annotation' node meta entry.
    """
    for node in nodes:
        if node is not None:
            if "quantization_annotation" not in node.meta:
                node.meta["quantization_annotation"] = QuantizationAnnotation()
            node.meta["quantization_annotation"]._annotated = True


def get_shared_qspec(
    node: Node, gm: GraphModule, quantization_config: QuantizationConfig
):
    """Returns a quantization constellation with a SharedQuantizationSpec for the inputs
    and output of the node 'node'.
    Parameters:
        node: a node with two inputs that should share quantization parameters.
        gm: The GraphModule containing the node. Used to inspect global graph features.
        quantization_config: a QuantizationConfig with the input QuantizationSpec to share.
    Returns:
        input_qspec_map: a dict[Node, QuantizationSpec] that maps the inputs of 'node' to
            the correct QuantizationSpec.
        shared_with_input0_qspec: The SharedQuantizationSpec to be used as output QuantizationSpec.

        Both outputs are None if one of the inputs is a node that can't be quantized.
    """
    input_act0 = cast(Node, node.args[0])
    input_act1 = node.args[1]

    input_act_qspec = quantization_config.get_input_act_qspec()
    shared_with_input0_qspec = SharedQuantizationSpec((input_act0, node))

    input_qspec_map = {}
    if isinstance(input_act0, Node):
        if not is_input_ok_for_quantization(input_act0, gm):
            return None, None
        input_qspec_map[input_act0] = input_act_qspec

    if isinstance(input_act1, Node):
        if not is_input_ok_for_quantization(input_act1, gm):
            return None, None
        if input_act0 is not input_act1:
            input_qspec_map[input_act1] = shared_with_input0_qspec
    return input_qspec_map, shared_with_input0_qspec


def is_input_ok_for_quantization(input_act: Node, gm: GraphModule) -> bool:
    """Check if an input can be quantized. The input cannot be quantized if:
    - The node does not output a float tensor, or
    - The node outputs a large scalar.
100 """ 101 return not ( 102 is_input_non_float_tensor(input_act) or is_input_large_scalar(input_act, gm) 103 ) 104 105 106def get_node_target(module: torch.nn.Module | GraphModule, target_str: str): 107 targets = target_str.split(".") 108 for target in targets[:-1]: 109 module = module.get_submodule(target) 110 return getattr(module, targets[-1]) 111 112 113def is_input_large_scalar(node: Node, gm: GraphModule): 114 """Check if input is a large scalar value. So that we can skip quantization for the node 115 since histc op (in HistogramObserver) only works for values up to certain upper bound 116 """ 117 if node.op == "get_attr" and isinstance(node.target, str): 118 tensor = get_node_target(gm, node.target) 119 # torch.histc works until this upper bound 120 HISTC_UPPER_BOUND = 3.4028235e15 121 return tensor.numel() == 1 and abs(tensor.item()) > HISTC_UPPER_BOUND 122 return False 123 124 125def is_input_non_float_tensor(node: Node) -> bool: 126 """Check if the input is not a float tensor, so that we can skip quantization for the node 127 since observers only works with float Tensors 128 """ 129 if "val" not in node.meta or not isinstance(node.meta["val"], FakeTensor): 130 return True 131 return node.meta["val"].dtype != torch.float32 132 133 134def is_share_obs_or_fq_op(op: Callable) -> bool: 135 """Returns whether the the operation 'op' can be quantized using a shared observer or 136 fake quantizer. This means that the operation can inherit it's quantization spec 137 from parent nodes. 138 """ 139 return op in [ 140 torch.ops.aten.hardtanh.default, 141 torch.ops.aten.hardtanh_.default, 142 torch.ops.aten.relu.default, 143 torch.ops.aten.mean.default, 144 torch.ops.aten.mean.dim, 145 torch.ops.aten.permute.default, 146 torch.ops.aten.permute_copy.default, 147 # TODO: remove? 148 torch.ops.aten.adaptive_avg_pool2d.default, 149 torch.ops.aten.avg_pool2d.default, 150 torch.ops.aten.max_pool2d.default, 151 torch.ops.aten.full.default, 152 torch.ops.aten.flatten.using_ints, 153 torch.ops.aten.dropout.default, 154 operator.getitem, 155 ] 156 157 158def propagate_annotation(model: GraphModule) -> None: 159 """For unannotated ops that can share observer or have fake quantizers, 160 annotate with a SharedQuantizationSpec, where the shared spec is the 161 output spec of the parent node. 162 This propagates output qspecs downward in the graph until 163 an op that is already annotated or can't share qspec is encountered. 164 """ 165 for n in model.graph.nodes: 166 n = cast(Node, n) 167 if is_annotated(n): 168 continue 169 if n.op != "call_function" or not is_share_obs_or_fq_op( 170 cast(Callable, n.target) 171 ): 172 continue 173 174 prev_node = n.args[0] 175 if not isinstance(prev_node, Node): 176 continue 177 178 quantization_annotation = cast( 179 QuantizationAnnotation | None, 180 prev_node.meta.get("quantization_annotation", None), 181 ) 182 if not quantization_annotation or not quantization_annotation.output_qspec: 183 continue 184 185 # propagate the previous output_qspec to the current node 186 shared_qspec = SharedQuantizationSpec(prev_node) 187 n.meta["quantization_annotation"] = QuantizationAnnotation( 188 input_qspec_map={ 189 prev_node: shared_qspec, 190 }, 191 output_qspec=shared_qspec, 192 _annotated=True, 193 ) 194