# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any, Dict, List, Tuple

import torch
from executorch.backends.cadence.aot.quantizer.patterns import (
    AddmmPattern,
    BmmPattern,
    Conv1dPattern,
    Conv2dPattern,
    LayerNormPattern,
    LinearPattern,
    MatmulPattern,
    ReluPattern0,
    ReluPattern1,
)
from executorch.backends.cadence.aot.quantizer.utils import (
    create_zero_bias_int32,
    find_sequential_partitions_aten,
    get_conv_args,
    quantize_tensor_multiplier,
)
from executorch.exir.pass_base import ExportPass
from torch import fx
from torch.fx import GraphModule
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.utils.fuser_utils import legalize_graph


# Use this to avoid pyre errors
# pyre-ignore[33]: `ArgsType` cannot alias to `Any`.
ArgsType = Any

# Use this for patterns that match multiple aten ops
ReluPatterns = (ReluPattern0, ReluPattern1)


# Helper function to get the args and kwargs for the linear replacement op
def get_args_and_kwargs_linear(
    graph_module: GraphModule,
    inputs_inputs: List[fx.Node],
    dequants_inputs: List[fx.Node],
    weights_inputs: List[fx.Node],
    dequants_weights: List[fx.Node],
    bias_inputs: List[fx.Node],
    quant_node: fx.Node,
) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
    """
    Returns the args and kwargs for the linear replacement op.
    """
    weight_scale = dequants_weights[0].args[1]
    # pyre-fixme[58]: Unsupported operand types
    bias_scale = dequants_inputs[0].args[1] * weight_scale
    requantize_scale = bias_scale / quant_node.args[1]
    requantize_scale_t = torch.tensor([requantize_scale])

    (out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale_t)

    # If no bias is available, create a zero bias tensor with the shape of weight[0]
    if not bias_inputs:
        weight_node = dequants_weights[0].args[0]
        assert isinstance(weight_node, fx.Node)
        bias = create_zero_bias_int32(graph_module, weight_node, bias_scale)
    else:
        bias = bias_inputs[0]

    # Create single-element tensors for weight_zero_point, out_multiplier, and
    # out_shift. The replacement op expects int32_t, while aten.full defaults
    # to int64_t, so we request the dtype explicitly.
    weight_zero_point_ = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        ([1], dequants_weights[0].args[2]),
        {"dtype": torch.int32},
    )
    out_multiplier_ = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        ([1], out_multiplier[0].item()),
        {"dtype": torch.int32},
    )
    out_shift_ = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        ([1], out_shift[0].item()),
        {"dtype": torch.int32},
    )

    args = tuple(inputs_inputs + weights_inputs + [bias])
    kwargs = {
        "src_zero_point": dequants_inputs[0].args[2],
        "weight_zero_point": weight_zero_point_,
        "out_multiplier": out_multiplier_,
        "out_shift": out_shift_,
        "out_zero_point": quant_node.args[2],
        "offset": None,
    }
    return args, kwargs
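

# Illustrative only: `quantize_tensor_multiplier` (imported above) decomposes a
# floating-point requantization scale into an integer multiplier and a shift so
# the kernel can requantize without floating-point math. A minimal sketch of
# that decomposition, assuming a Q31 fixed-point multiplier (the exact format
# the Cadence kernels expect may differ):
def _example_decompose_scale(scale: float) -> Tuple[int, int]:
    import math

    # scale == mantissa * 2**exponent, with 0.5 <= mantissa < 1
    mantissa, exponent = math.frexp(scale)
    # Round the mantissa into a Q31 multiplier; the shift is applied at
    # runtime, so scale ~= multiplier * 2**(exponent - 31).
    # (A production version must renormalize the mantissa ~= 1.0 edge case.)
    multiplier = round(mantissa * (1 << 31))
    return multiplier, exponent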


# Helper function to get the args and kwargs for the layer norm replacement op
def get_args_and_kwargs_layer_norm(
    graph_module: GraphModule,
    inputs_inputs: List[fx.Node],
    dequants_inputs: List[fx.Node],
    other_inputs: List[fx.Node],
    quant_node: fx.Node,
) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
    """
    Returns the args and kwargs for the layer norm replacement op.
    """
    # Check that the input is per-tensor quantized
    # TODO(matthiascremon): add proper support and testing for per-channel quantization
    assert isinstance(dequants_inputs[0].args[1], float) and isinstance(
        dequants_inputs[0].args[2], int
    ), "per-channel quantization is not supported for layer norm; scale and zero_point must both be scalars"

    # Make the scale and zero_point tensors
    scale_tensor = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        (
            [1],
            dequants_inputs[0].args[1],
        ),
        {"dtype": torch.float32},
    )
    zero_point_tensor = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        (
            [1],
            dequants_inputs[0].args[2],
        ),
        {"dtype": torch.int32},
    )

    weight = other_inputs[1] if len(other_inputs) > 1 else None

    if weight is None:
        # Default the affine weight to all ones over the normalized shape
        weight = graph_module.graph.call_function(
            torch.ops.aten.full.default,
            (
                other_inputs[0],
                1,
            ),
            {"dtype": torch.float32},
        )

    bias = other_inputs[2] if len(other_inputs) > 2 else None

    if bias is None:
        # Default the affine bias to all zeros over the normalized shape
        bias = graph_module.graph.call_function(
            torch.ops.aten.full.default,
            (
                other_inputs[0],
                0,
            ),
            {"dtype": torch.float32},
        )

    # Make the args and kwargs for the replacement op
    args = tuple(inputs_inputs + [scale_tensor] + [zero_point_tensor])
    kwargs = {
        "normalized_shape": other_inputs[0],
        "weight": weight,
        "bias": bias,
        "eps": 1e-05,
        "output_scale": quant_node.args[1],
        "output_zero_point": quant_node.args[2],
    }
    return args, kwargs


def get_args_and_kwargs_matmul(
    inputs_inputs: List[fx.Node],
    dequants_inputs: List[fx.Node],
    quant_node: fx.Node,
) -> Tuple[Tuple[ArgsType, ...], Dict[str, ArgsType]]:
    requantize_scale = (
        # pyre-ignore[58]: Unsupported operand
        dequants_inputs[0].args[1]
        * dequants_inputs[1].args[1]
    ) / quant_node.args[1]
    requantize_scale_t = torch.tensor([requantize_scale])

    (out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale_t)

    args = (
        inputs_inputs[0],
        dequants_inputs[0].args[2],
        inputs_inputs[1],
        dequants_inputs[1].args[2],
        None,
    )

    kwargs = {
        "out_multiplier": out_multiplier[0].item(),
        "out_shift": out_shift[0].item(),
        "out_zero_point": quant_node.args[2],
        "transposed": False,
    }
    return args, kwargs
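

# Illustrative only: why the fused requantize scale is s_x * s_y / s_out. With
# per-tensor affine quantization, x = s_x * (q_x - z_x) and y = s_y * (q_y - z_y),
# so every element of x @ y is s_x * s_y times an integer accumulator, and one
# multiply by s_x * s_y / s_out maps that accumulator onto the output scale.
# A quick numeric check of the identity (dyadic scales keep float math exact):
def _check_matmul_requantize_scale() -> None:
    s_x, z_x = 0.5, 3
    s_y, z_y = 0.25, -1
    s_out = 0.03125
    q_x = torch.randint(-128, 128, (4, 8), dtype=torch.int32)
    q_y = torch.randint(-128, 128, (8, 2), dtype=torch.int32)
    # Reference: dequantize both operands, matmul in float, divide by the out scale.
    ref = ((q_x - z_x) * s_x) @ ((q_y - z_y) * s_y) / s_out
    # Fused form: one integer matmul, then a single requantize multiply.
    # (The cast to float only selects the dense matmul kernel; values stay exact.)
    acc = (q_x - z_x).float() @ (q_y - z_y).float()
    assert torch.allclose(ref, acc * (s_x * s_y / s_out))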


def get_args_and_kwargs_conv(
    graph_module: GraphModule,
    inputs_inputs: List[fx.Node],
    dequants_inputs: List[fx.Node],
    weights_inputs: List[fx.Node],
    dequants_weights: List[fx.Node],
    bias_inputs: List[fx.Node],
    quant_node: fx.Node,
    op_node: fx.Node,
) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
    weight_scale = dequants_weights[0].args[1]
    weight_zero_point = dequants_weights[0].args[2]
    # pyre-fixme[58]: Unsupported operand types
    bias_scale = dequants_inputs[0].args[1] * weight_scale
    stride = [1, 1] if len(op_node.args) < 4 else get_conv_args(op_node.args[3], 1)
    padding = [0, 0] if len(op_node.args) < 5 else get_conv_args(op_node.args[4], 0)
    dilation = [1, 1] if len(op_node.args) < 6 else get_conv_args(op_node.args[5], 1)
    groups = 1 if len(op_node.args) < 7 else op_node.args[6]

    # If no bias is available, create a zero bias tensor with the shape of weight[0]
    if not bias_inputs:
        weight_node = dequants_weights[0].args[0]
        assert isinstance(weight_node, fx.Node)
        bias = create_zero_bias_int32(graph_module, weight_node, bias_scale)
    else:
        bias = bias_inputs[0]

    # Compute the out multiplier and out shift. They are used when the conv op
    # is replaced by its quantized counterpart; we compute them a priori for
    # simplicity, but may revisit that decision.
    requantize_scale = bias_scale / quant_node.args[1]
    requantize_scale_t = torch.tensor([requantize_scale])

    (out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale_t)

    out_multiplier_ = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        ([1], out_multiplier[0].item()),
        {"dtype": torch.int32},
    )
    out_shift_ = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        ([1], out_shift[0].item()),
        {"dtype": torch.int32},
    )

    # Create a single-element tensor for the weight zero point
    weight_zero_point_tensor = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        ([1], weight_zero_point),
        {"dtype": torch.int32},
    )

    # Create a single-element tensor for the bias scale
    bias_scale_tensor = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        ([1], bias_scale),
        {"dtype": torch.float32},
    )

    # Make the args and kwargs for the replacement op
    args = tuple(inputs_inputs + weights_inputs + [bias])
    kwargs = {
        "stride": stride,
        "padding": padding,
        "dilation": dilation,
        "groups": groups,
        "input_zero_point": dequants_inputs[0].args[2],
        "weight_zero_point": weight_zero_point_tensor,
        "bias_scale": bias_scale_tensor,
        "out_scale": quant_node.args[1],
        "out_zero_point": quant_node.args[2],
        "out_multiplier": out_multiplier_,
        "out_shift": out_shift_,
        "channel_last": False,
    }
    return args, kwargs
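

# Illustrative only: the stride/padding/dilation defaulting above relies on
# `get_conv_args` (imported at the top) to bring Conv1d and Conv2d parameters
# into a single 2-d form. A rough sketch of that normalization, assuming a
# lone Conv1d value is promoted by filling the extra dimension with the
# supplied default (the real helper may differ in detail):
def _example_promote_conv_arg(arg: List[int], default: int) -> List[int]:
    # A Conv2d argument already has two entries; a Conv1d one gains the default.
    return list(arg) if len(arg) == 2 else [default, arg[0]]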


def get_args_and_kwargs_relu(
    graph_module: GraphModule,
    inputs_inputs: List[fx.Node],
    dequants_inputs: List[fx.Node],
    quant_node: fx.Node,
) -> Tuple[Tuple[ArgsType], Dict[str, ArgsType]]:
    input_scale = dequants_inputs[0].args[1]
    # pyre-fixme[58]: Unsupported operand types
    requantize_scale = input_scale / quant_node.args[1]
    requantize_scale_t = torch.tensor([requantize_scale])

    (out_multiplier, out_shift) = quantize_tensor_multiplier(requantize_scale_t)

    # Make the args and kwargs for the replacement op
    args = tuple(inputs_inputs)

    X_zero_point = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        ([1], dequants_inputs[0].args[2]),
        {"dtype": torch.int32},
    )
    out_multiplier_ = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        ([1], out_multiplier[0].item()),
        {"dtype": torch.int32},
    )
    out_shift_ = graph_module.graph.call_function(
        torch.ops.aten.full.default,
        ([1], out_shift[0].item()),
        {"dtype": torch.int32},
    )

    kwargs = {
        "X_zero_point": X_zero_point,
        "out_zero_point": quant_node.args[2],
        "out_multiplier": out_multiplier_,
        "out_shift": out_shift_,
    }
    return args, kwargs


class QuantFusion(ExportPass):
    # pyre-ignore[2]: Parameter `patterns` has no type specified
    def __init__(self, patterns) -> None:
        super().__init__()
        # pyre-ignore[4]: Attribute `patterns` of class `QuantFusion` has no type specified
        self.patterns = patterns

    def call(self, graph_module: fx.GraphModule) -> PassResult:  # noqa: C901
        for pattern in self.patterns:
            fused_partitions = find_sequential_partitions_aten(
                graph_module,
                pattern.partition_types(),
            )
            for fused_partition in fused_partitions:
                anchors = pattern.get_anchors(graph_module, fused_partition)
                if not anchors:
                    continue
                if any(self.is_fused(p.nodes) for p in fused_partition):
                    continue

                for p in fused_partition:
                    self.mark_fused(p.nodes)

                dequants_inputs = []
                for node, idx in anchors.inputs:
                    if (
                        node.args[idx].target
                        == torch.ops.quantized_decomposed.dequantize_per_tensor.default
                    ):
                        dequants_inputs.append(node.args[idx])
                dequants_weights = []
                for node, idx in anchors.weights:
                    if (
                        node.args[idx].target
                        == torch.ops.quantized_decomposed.dequantize_per_tensor.default
                    ):
                        dequants_weights.append(node.args[idx])
                dequants_biases = []
                for node, idx, *_spec in anchors.biases:
                    if (
                        node.args[idx].target
                        == torch.ops.quantized_decomposed.dequantize_per_tensor.default
                    ):
                        dequants_biases.append(node.args[idx])

                inputs_inputs = [node.args[0] for node in dequants_inputs]
                weights_inputs = [node.args[0] for node in dequants_weights]
                bias_inputs = [node.args[0] for node in dequants_biases]
                other_inputs = [node.args[idx] for node, idx in anchors.others]

                # The op node is the node of the first (node, idx) tuple in the
                # anchors' output list
                op_node = anchors.output[0][0]

                assert len(op_node.users) == 1
                quant_node = list(op_node.users.keys())[0]

                with graph_module.graph.inserting_after(op_node):
                    args = tuple(
                        inputs_inputs + weights_inputs + other_inputs + bias_inputs
                    )
                    kwargs = {}
                    if isinstance(pattern, (Conv1dPattern, Conv2dPattern)):
                        args, kwargs = get_args_and_kwargs_conv(
                            graph_module,
                            inputs_inputs,
                            dequants_inputs,
                            weights_inputs,
                            dequants_weights,
                            bias_inputs,
                            quant_node,
                            op_node,
                        )
                    elif isinstance(pattern, LinearPattern):
                        args, kwargs = get_args_and_kwargs_linear(
                            graph_module,
                            inputs_inputs,
                            dequants_inputs,
                            weights_inputs,
                            dequants_weights,
                            bias_inputs,
                            quant_node,
                        )
                    elif isinstance(pattern, LayerNormPattern):
                        args, kwargs = get_args_and_kwargs_layer_norm(
                            graph_module,
                            inputs_inputs,
                            dequants_inputs,
                            other_inputs,
                            quant_node,
                        )
                    elif isinstance(pattern, (BmmPattern, MatmulPattern)):
                        args, kwargs = get_args_and_kwargs_matmul(
                            inputs_inputs,
                            dequants_inputs,
                            quant_node,
                        )
                    elif isinstance(pattern, AddmmPattern):
                        # Transpose the weight tensor: addmm's mat2 is laid out
                        # (in_features, out_features), while the linear op
                        # expects (out_features, in_features)
                        transposed_weights = graph_module.graph.call_function(
                            torch.ops.aten.transpose.int,
                            (weights_inputs[0], 0, 1),
                        )
                        # Call linear with the transposed weight
                        args, kwargs = get_args_and_kwargs_linear(
                            graph_module,
                            inputs_inputs,
                            dequants_inputs,
                            [transposed_weights],
                            dequants_weights,
                            bias_inputs,
                            quant_node,
                        )
                    elif isinstance(pattern, ReluPatterns):
                        args, kwargs = get_args_and_kwargs_relu(
                            graph_module,
                            inputs_inputs,
                            dequants_inputs,
                            quant_node,
                        )
                    fused = graph_module.graph.call_function(
                        pattern.replacement_op(),
                        args,
                        kwargs,
                    )
                    fused.meta = quant_node.meta
                    quant_node.replace_all_uses_with(fused)

            legalize_graph(graph_module)
            graph_module.graph.eliminate_dead_code()
            graph_module.recompile()
        return PassResult(graph_module, True)
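
    # Fused partitions are de-duplicated by tagging every constituent node's
    # `meta` dict: `mark_fused` writes the tag and `is_fused` checks for it,
    # so no node is ever folded into more than one replacement op.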

    @classmethod
    # pyre-ignore[2]: Parameter `nodes` has no type specified
    def is_fused(cls, nodes) -> bool:
        return any(cls.__qualname__ in n.meta for n in nodes)

    @classmethod
    # pyre-ignore[2]: Parameter `nodes` has no type specified
    def mark_fused(cls, nodes) -> None:
        for n in nodes:
            n.meta[cls.__qualname__] = True
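

# Illustrative only: a minimal sketch of how this pass might be driven on a
# graph module that was prepared and converted with the matching Cadence
# quantizer. The pattern list below is hypothetical; in practice it should be
# the same list of pattern instances the quantizer annotated with.
def _example_run_quant_fusion(graph_module: fx.GraphModule) -> fx.GraphModule:
    patterns = [LinearPattern(), Conv1dPattern(), Conv2dPattern()]  # hypothetical selection
    QuantFusion(patterns).call(graph_module)
    return graph_module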