# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import itertools
from collections import OrderedDict
from math import frexp, isclose, trunc
from typing import Any, Dict, List, Tuple, Type

import torch
from torch import fx
from torch._ops import OpOverload
from torch.ao.quantization import ObserverOrFakeQuantize

from torch.fx import GraphModule
from torch.fx.passes.utils.source_matcher_utils import (
    check_subgraphs_connected,
    SourcePartition,
)


def quantize_tensor_multiplier(
    requantize_scale_tensor: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Given requantize_scale_tensor with values in the interval (0, 1),
    produce a pair of tensors (out_multiplier, right_shift) where out_multiplier
    is an int32 tensor representing fixed-point values in the interval [-1, 1),
    and right_shift is an amount to shift right by, so that the floating-point
    multiplication of some int32 input with each value of requantize_scale_tensor:
        result = int32_value * requantize_scale_tensor[i]
    is best approximated by the integer-arithmetic-only code:
        result = RoundingRightShift(FixedPointMultiplication(int32_value,
                                    out_multiplier[i]), right_shift[i])
    """

    # This is identical to C++11 std::round(). Python's built-in round() uses
    # banker's rounding (round half to even), whereas C++ rounds halfway cases
    # away from zero.
    def round_away_zero(f: float) -> int:
        r = -0.5 if (f < 0) else 0.5
        return trunc(f + r)

    def quantize_scalar_multiplier(requantize_scale: float) -> Tuple[int, int]:
        significand, exponent = frexp(requantize_scale)
        significand_q31 = int(round_away_zero(significand * (1 << 31)))
        # Handle the special case when the real multiplier was so close to 1
        # that its fixed-point approximation was indistinguishable from 1.
        # We handle this by dividing it by two and incrementing the exponent
        # (the right shift amount) by 1.
        if significand_q31 == (1 << 31):
            significand_q31 //= 2
            exponent += 1

        # Verify that the decomposition of requantize_scale into significand
        # and exponent is correct.
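        # For example (illustrative values): requantize_scale = 0.125 decomposes
        # into significand = 0.5 and exponent = -2, so significand_q31 becomes
        # 0.5 * 2**31 = 1 << 30, and (1 << 30) / (1 << 31) * 2**-2 reconstructs
        # 0.125 exactly.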
        reconstructed = significand_q31 / (1 << 31) * pow(2, exponent)
        assert isclose(
            requantize_scale, reconstructed, rel_tol=1e-4, abs_tol=1e-4
        ), "computation of significand and exponent from requantize_scale is not accurate"

        return (significand_q31, exponent)

    # Flatten the input scale tensor so that we can operate on individual values
    orig_shape = requantize_scale_tensor.shape
    flattened_tensor = requantize_scale_tensor.flatten().to(torch.float32)
    out_multiplier = torch.zeros(flattened_tensor.shape, dtype=torch.int32)
    right_shift = torch.zeros(flattened_tensor.shape, dtype=torch.int32)

    # Iterate over the flattened scale tensor and decompose each value into a
    # significand (out_multiplier) and an exponent (right_shift)
    for idx, scale in enumerate(flattened_tensor):
        (si, ex) = quantize_scalar_multiplier(scale)
        out_multiplier[idx], right_shift[idx] = si, ex

    # Reshape the tensors back to the original shape
    out_multiplier = out_multiplier.reshape(orig_shape)
    right_shift = right_shift.reshape(orig_shape)

    return (out_multiplier, right_shift)


def is_annotated(nodes: List[fx.Node]) -> bool:
    annotated = False
    for node in nodes:
        annotated = annotated or (
            "quantization_annotation" in node.meta
            and node.meta["quantization_annotation"]._annotated
        )
    return annotated


def no_outside_users(fused_partition) -> bool:
    """
    Checks that every partition except the last has exactly one output node,
    and that this output node has no users outside its own partition.
    """
    for source_partition in fused_partition[:-1]:
        if len(source_partition.output_nodes) != 1:
            return False
        if len(source_partition.output_nodes[0].users) != 1:
            return False
    return True


def create_zero_bias_int32(
    graph_module: GraphModule,
    weight_node: fx.Node,
    bias_scale: float,
) -> fx.Node:
    """
    Creates a zero-filled int32 bias tensor whose length is the weight's
    output-channel dimension, weight.shape[0].
    """
    attr_node = getattr(graph_module, weight_node.target)
    weight_shape = list(attr_node.shape)
    bias_shape = weight_shape[0]
    return graph_module.graph.call_function(
        torch.ops.aten.full.default,
        ([bias_shape], 0.0),
        {"dtype": torch.int32},
    )


def get_bias_qparams(
    obs_or_fqs: List[ObserverOrFakeQuantize],
) -> Tuple[torch.Tensor, torch.Tensor]:
    act_scale, _ = obs_or_fqs[0].calculate_qparams()
    weight_scale, _ = obs_or_fqs[1].calculate_qparams()
    bias_scale = act_scale * weight_scale
    bias_zero_point = torch.zeros_like(bias_scale, dtype=torch.int32)
    return bias_scale, bias_zero_point


def get_conv_args(arg, first_val: int) -> List[fx.Node]:
    return arg if len(arg) == 2 else [first_val, arg[0]]


def get_aten_node_target_partitions(
    graph: torch.fx.Graph,
    wanted_original_aten_op: List[OpOverload],
):
    """
    Args:
        graph: The graph we want to partition
        wanted_original_aten_op: List of original_aten ops (OpOverload)

    Returns:
        Dictionary mapping the source module or function type recorded in each
        matched node's source_fn_stack to a list of SourcePartitions, one per
        node that was decomposed from the given aten ops.
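
        For example (illustrative): if the graph contains two aten.addmm.default
        nodes decomposed from nn.Linear modules and addmm is the requested op,
        the result is roughly
            {torch.nn.Linear: [SourcePartition([addmm], Linear, ...),
                               SourcePartition([addmm_1], Linear, ...)]}
        with each matched node forming its own single-node partition.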
154 """ 155 modules: Dict[Type, Dict[str, List[torch.fx.Node]]] = {} 156 157 for node in graph.nodes: 158 # The metadata source_fn should contain a tuple of a unique name for the 159 # source, and the source function if the node is decomposed from a 160 # function, or the type of module if the node is decomposed from a leaf 161 # module 162 # TODO(matthiascremon): look into ways to avoid using source_fn_stack 163 if (source_fn_st := node.meta.get("source_fn_stack")) is None: 164 continue 165 166 source_fn = source_fn_st[-1] 167 if node.target not in wanted_original_aten_op: 168 continue 169 170 diff_modules = modules.setdefault(source_fn[1], {}) 171 partition = diff_modules.setdefault(node.name, []) 172 partition.append(node) 173 174 def make_partition( 175 nodes: List[torch.fx.Node], module_type: Type 176 ) -> SourcePartition: 177 input_nodes = set() 178 output_nodes = set() 179 params = set() 180 for node in nodes: 181 for arg in node.args: 182 if isinstance(arg, torch.fx.Node) and arg not in nodes: 183 input_nodes.add(arg) 184 185 if node.op == "get_attr": 186 params.add(node) 187 188 for user in node.users.keys(): 189 if user not in nodes: 190 output_nodes.add(node) 191 192 return SourcePartition( 193 nodes, 194 module_type, 195 list(input_nodes), 196 list(output_nodes), 197 list(params), # type: ignore[arg-type] 198 ) 199 200 ret: Dict[Type[Any], List[SourcePartition]] = {} 201 202 for k, v in modules.items(): 203 ret[k] = [make_partition(partition, k) for partition in v.values()] 204 205 return ret 206 207 208def _partitions_sequential(partitions: Tuple[SourcePartition]) -> bool: 209 prev_partition = None 210 for partition in partitions: 211 if prev_partition is not None and not check_subgraphs_connected( 212 prev_partition, partition 213 ): 214 return False 215 prev_partition = partition 216 return True 217 218 219def find_sequential_partitions_aten( 220 gm: torch.fx.GraphModule, 221 partition_types: List[Any], 222): 223 typed_partitions: OrderedDict[Any, List[SourcePartition]] = OrderedDict() 224 for partition_type in partition_types: 225 partitions = get_aten_node_target_partitions(gm.graph, [partition_type]) 226 typed_partitions[partition_type] = list( 227 itertools.chain.from_iterable(partitions.values()) 228 ) 229 230 typed_partitions_list = list(typed_partitions.values()) 231 fusion_candidates = itertools.product(*typed_partitions_list) 232 fused_partitions = [] 233 for candidate in fusion_candidates: 234 if _partitions_sequential(candidate): 235 fused_partitions.append(candidate) 236 return fused_partitions 237