# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.


# This file contains all the functions that fuse ops in the fx graph.

import logging
import math
import operator
from collections import deque
from numbers import Number
from typing import cast, Sequence

import torch
import torch.fx
from executorch.backends.cadence.aot.compiler_utils import (
    broadcastable,
    get_cascaded_ops,
    get_permuted_dims,
    get_scale,
    get_shape,
    get_tensor_from_attr,
    get_transposed_dims,
    get_zero_point,
)
from executorch.backends.cadence.aot.pass_utils import (
    CadencePassAttribute,
    register_cadence_pass,
)
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
from executorch.exir.pass_base import ExportPass, NodeMetadata, PassResult, ProxyValue
from executorch.exir.passes import dead_code_elimination_pass
from executorch.exir.passes.spec_prop_pass import SpecPropPass
from torch.fx.node import Argument
from torch.nn.utils.fusion import fuse_conv_bn_weights


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class FuseMMWithAdd(ExportPass):
    # Return true if the node is a view node.
    def is_view_node(self, node: torch.fx.Node):
        return node.target == exir_ops.edge.aten.view_copy.default

    def fuse_mm_with_add(self, graph_module: torch.fx.GraphModule):
        """
        Given a graph of the form:
        X = aten.mm(A, B)
        Y = aten.add(X, C)
        Fuse X and Y into a single addmm node, after making sure that we can
        broadcast C into X.
        There could be a view node that takes a view of X, and feeds that
        to the aten.add node:
        X = aten.mm(A, B)
        Y = X.view()
        Z = aten.add(Y, C)
        Handle this case as well. There are a few conditions for the
        optimization to be valid:
        1. There should be a single user of the mm node, otherwise we cannot
           remove it.
        2. There should be a single user of the add node, otherwise we cannot
           fuse it with mm.
        """
        graph = graph_module.graph
        for node in graph.nodes:
            # We want to discover a chain of mm -> add, or mm -> view -> add.
            # Only proceed if the current node is an mm node, and has only one
            # user/successor.
            if node.target != exir_ops.edge.aten.mm.default or len(node.users) != 1:
                continue

            # Our addmm implementation computes (mat1 * mat2 + bias). So the
            # addmm node in the graph should have three args. We collectively
            # term mat1 and mat2 as mm_arg since they are the args of the mm node,
            # and bias as bias_arg.
            # Since we have already discovered the mm node, we can get mat1 and
            # mat2 by iterating over its args. So the current node is mm_arg.
            # bias_arg can be found once we discover the add op that consumes
            # the output of this mm node. Our next step is to find the add op.
            mm_arg = node
            user = list(node.users.keys())[0]
            # intermediate_view is True when the fusion case is mm -> view -> add
            intermediate_view = False
            # Check if the single user of the mm node is a view op. If so, our
            # graph could potentially have mm -> view -> add. We need to skip
            # the view op, and check if its successor is the add op. One condition
            # we need to verify is that the view op must have only a single user
            # (the add op).
            if self.is_view_node(user) and len(user.users) == 1:
                # We want to maintain two invariants:
                # (1) 'user' is a potential add op that will get fused with the
                # mm node;
                # (2) 'node' is the single predecessor of 'user' that is either
                # the mm node, or the current view node.
                # To maintain the invariant, we must mark this view op as 'node',
                # and its single successor as 'user'.
                intermediate_view = True
                node = user
                user = list(node.users.keys())[0]

            # Thanks to the invariant, we can now simply check if 'user' is an
            # add op. We also want to ensure that the add op has only one user,
            # otherwise we will not be able to eliminate the add op post fusion.
            if user.target != exir_ops.edge.aten.add.Tensor or len(user.users) != 1:
                continue

            # At this point, we have found an mm and an add node that we can
            # fuse together. One arg of the add op is 'node' (thanks to the
            # invariant). Find the other arg, and tag it as bias_arg.
            assert len(user.args) == 2
            bias_arg = user.args[1] if user.args[0] == node else user.args[0]

            # As a last check, make sure that we can broadcast the bias tensor
            # to the output of mm.
            mm_arg_shape = get_shape(graph_module, mm_arg)
            bias_arg_shape = get_shape(graph_module, bias_arg)
            if (
                mm_arg_shape is None
                or bias_arg_shape is None
                or not broadcastable(mm_arg_shape, bias_arg_shape)
            ):
                continue

            # Create a new addmm node, and insert it before the add node. DCE should
            # take care of removing the dead mm and/or view node. Based on the
            # invariant, the add node corresponds to 'user'.
            with graph.inserting_before(user):
                addmm_node = graph.call_function(
                    exir_ops.edge.aten.addmm.default,
                    args=(bias_arg, mm_arg.args[0], mm_arg.args[1]),
                )
            # Replace all the uses of the add node with the addmm node, and remove
            # the add node from the graph.
            user.replace_all_uses_with(addmm_node)
            graph.erase_node(user)

            # As a finishing step, we want to ensure that the output of addmm is
            # in the expected shape. For example, let us assume the following
            # input, where A and B are (4, 4) sized tensors, and C is a (1, 4)
            # sized tensor:
            # T1 = torch.mm(A, B)
            # T2 = T1.view((2, 2, 4))
            # return torch.add(T2, C)
            # Here, the expectation is to get an output of size (2, 2, 4), which
            # is the output shape of the view node T2. However, the fused addmm
            # will return an output of shape (4, 4). In a nutshell, we need to
            # take care of the output shape when the following two conditions
            # are met:
            # 1. The fusion case is mm -> view -> add (i.e., intermediate_view
            #    is True)
            # 2. The single successor of addmm is not a view op.
            addmm_user = list(addmm_node.users.keys())[0]
            if intermediate_view and not self.is_view_node(addmm_user):
                # Create a view node that correctly reshapes the output of the
                # addmm node to match the output shape of the original add node.
                # Thanks to our invariant, we know that the correct shape is held
                # by 'node', which points to the view op in the mm -> view -> add
                # chain. We create its copy, and insert it just before addmm_user.
                with graph.inserting_before(addmm_user):
                    view_copy_node = graph_module.graph.node_copy(node)
                # Any uses of addmm are replaced with this view_copy node.
                addmm_node.replace_all_uses_with(view_copy_node)
                # Now we massage the args of the view_copy node, so that it takes
                # a view of the addmm node.
                view_args = list(view_copy_node.args)
                view_args[0] = addmm_node
                view_copy_node.args = tuple(view_args)

        graph_module.recompile()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        # Compute the spec prop pass before we begin the fusion pipeline
        result = SpecPropPass()(graph_module)
        assert result is not None
        self.fuse_mm_with_add(result.graph_module)
        result = super().call(result.graph_module)
        return result


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class FuseBatchNormWithConv(ExportPass):
    """
    This pass fuses a conv op with batchnorm if the following two conditions
    are met:
    1. The only user of conv op should be batchnorm;
    2. Only the first element from the batchnorm output tuple should be used
       in the graph.
    """

    def fuse_batch_norm_with_conv(self, graph_module: torch.fx.GraphModule) -> None:
        graph = graph_module.graph
        for conv in graph.nodes:
            # We want to discover a chain of conv1d -> batch_norm.
            # Only proceed if the current node is a conv1d node, and has a single
            # user/successor.
            if (
                conv.target != exir_ops.edge.aten.convolution.default
                or len(conv.users) != 1
            ):
                continue

            # The single user of conv op must be batch_norm. If not, bail.
            bn = list(conv.users.keys())[0]
            if bn.target != exir_ops.edge.aten.native_batch_norm.default:
                continue

            # All the users of batchnorm node must be getitem ops. batchnorm
            # returns a 3-element tuple. Each user must only access the first
            # element of the tuple.
            if [
                (user.target == operator.getitem and user.args[1] == 0)
                for user in bn.users
            ].count(False):
                continue

            # Check that the weights for conv1d and batchnorm are both params
            if [node.op == "get_attr" for node in {conv.args[1], bn.args[1]}].count(
                False
            ):
                continue

            # Get the parameters from conv op
            assert len(conv.args) == 9
            conv_weight = get_tensor_from_attr(graph_module, conv.args[1])
            assert isinstance(conv_weight, torch.Tensor)
            conv_bias = get_tensor_from_attr(graph_module, conv.args[2])
            transpose = conv.args[6]

            # Get the parameters from the batchnorm op
            assert len(bn.args) == 8
            bn_weight = get_tensor_from_attr(graph_module, bn.args[1])
            bn_bias = get_tensor_from_attr(graph_module, bn.args[2])
            running_mean = get_tensor_from_attr(graph_module, bn.args[3])
            assert isinstance(running_mean, torch.Tensor)
            running_var = get_tensor_from_attr(graph_module, bn.args[4])
            assert isinstance(running_var, torch.Tensor)
            eps = bn.args[-1]

            # Compute the updated weight and bias after fusing conv op
            # with batchnorm op.
            fused_weight, fused_bias = fuse_conv_bn_weights(
                conv_weight,
                conv_bias,
                running_mean,
                running_var,
                eps,
                bn_weight,
                bn_bias,
                transpose,
            )

            # Modify the graph by updating the weight and bias of conv op
            # with the fused weight and bias params, and replacing all the users
            # of getitem(batchnorm) with the conv op.
            with graph.inserting_before(conv):
                fused_weight_name = f"_fused_with_bn_weight_{self.counter}"
                graph_module.register_parameter(fused_weight_name, fused_weight)
                fused_weight_node = graph.get_attr(fused_weight_name)
                fused_bias_name = f"_fused_with_bn_bias_{self.counter}"
                graph_module.register_parameter(fused_bias_name, fused_bias)
                fused_bias_node = graph.get_attr(fused_bias_name)

            # Update the weight and bias of conv op
            conv_args = list(conv.args)
            conv_args[1] = fused_weight_node
            conv_args[2] = fused_bias_node
            conv.args = tuple(conv_args)
            # Remove any use of batchnorm from the graph
            for user in bn.users:
                assert user.target == operator.getitem
                user.replace_all_uses_with(conv)
            self.counter += 1

        graph_module.recompile()

    def __init__(self):
        super().__init__()
        self.counter = 0

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        self.fuse_batch_norm_with_conv(graph_module)
        result = super().call(graph_module)
        return result


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class FuseQuantizedBatchNormWithConv(ExportPass):
    """
    This pass fuses a quantized::conv op with quantized::batchnorm if the
    following two conditions are met:
    1. The only user of quantized::conv op should be quantized::batchnorm;
    2. The outputs of both ops are quantized with the same scale and zero_point.
    """

    def fuse_quantized_batch_norm_with_conv(
        self, graph_module: torch.fx.GraphModule
    ) -> None:
        graph = graph_module.graph
        for conv in graph.nodes:
            # We want to discover a chain of quantized::conv1d ->
            # quantized::batch_norm. Only proceed if the current node is a
            # quantized::conv node, and has a single user/successor.
            if (
                conv.target
                not in {
                    exir_ops.edge.quantized.conv1d.default,
                    exir_ops.edge.quantized.conv2d.new,
                }
                or len(conv.users) != 1
            ):
                continue

            # The single user of conv op must be batch_norm. If not, bail.
            bn = list(conv.users.keys())[0]
            if bn.target not in {
                exir_ops.edge.quantized.batch_norm1d.default,
                exir_ops.edge.quantized.batch_norm2d.default,
            }:
                continue

            # The outputs of conv and bn must both have the same scale and zero_point
            if not math.isclose(
                conv.args[-2], bn.args[-2], rel_tol=1e-05, abs_tol=1e-05
            ):
                continue
            if conv.args[-1] != bn.args[-1]:
                continue

            # The weight and bias of quantized::conv op are packed in the second
            # arg. Unpack them.
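            # (The packed arg holds prepacked conv parameters, e.g. a
            # torch.classes.quantized.Conv2dPackedParamsBase object for conv2d,
            # whose unpack() call below returns the (weight, bias) pair.)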
            assert conv.args[1].op == "get_attr"
            packed_args = getattr(graph_module, conv.args[1].target)
            conv_weight_tensor, conv_bias_tensor = packed_args.unpack()
            # Assert that we have discovered the conv op's weight and bias tensors
            assert isinstance(conv_weight_tensor, torch.Tensor)
            assert conv_bias_tensor is None or isinstance(
                conv_bias_tensor, torch.Tensor
            )

            # Get the scale, zero_point, and dtype of convolution weight
            assert conv_weight_tensor.is_quantized
            per_tensor_quantization = (
                conv_weight_tensor.qscheme() == torch.per_tensor_affine
            )
            weight_dtype = conv_weight_tensor.dtype
            weight_scale = get_scale(conv_weight_tensor)
            weight_zero_point = get_zero_point(conv_weight_tensor, reduce=False)
            weight_axis = (
                0
                if per_tensor_quantization
                else conv_weight_tensor.q_per_channel_axis()
            )
            # Dequantize the convolution weight
            conv_weight_tensor = conv_weight_tensor.dequantize()

            # Get the parameters from the batchnorm op
            assert len(bn.args) == 8
            (bn_weight, bn_bias, running_mean, running_var, eps) = bn.args[1:6]
            # Get the tensors from the batchnorm args
            bn_weight_tensor = get_tensor_from_attr(graph_module, bn_weight)
            bn_bias_tensor = get_tensor_from_attr(graph_module, bn_bias)
            running_mean_tensor = get_tensor_from_attr(graph_module, running_mean)
            running_var_tensor = get_tensor_from_attr(graph_module, running_var)

            # Assert that we have discovered the batch_norm op's tensors
            assert bn_weight_tensor is None or isinstance(
                bn_weight_tensor, torch.Tensor
            )
            assert bn_bias_tensor is None or isinstance(bn_bias_tensor, torch.Tensor)
            assert isinstance(running_mean_tensor, torch.Tensor)
            assert isinstance(running_var_tensor, torch.Tensor)

            # Get the fused weights and bias
            fused_weight, fused_bias = fuse_conv_bn_weights(
                conv_weight_tensor,
                conv_bias_tensor,
                running_mean_tensor,
                running_var_tensor,
                eps,
                bn_weight_tensor,
                bn_bias_tensor,
                transpose=False,
            )

            # Requantize the fused weight with the scale and zero point of the
            # quantized::conv's weight
            if per_tensor_quantization:
                fused_weight = torch.quantize_per_tensor(
                    fused_weight,
                    weight_scale.item(),
                    cast(int, weight_zero_point.item()),
                    weight_dtype,
                )
            else:
                fused_weight = torch.quantize_per_channel(
                    fused_weight,
                    weight_scale,
                    weight_zero_point,
                    weight_axis,
                    weight_dtype,
                )

            # Now that we have the fused weight and bias, pack them for the
            # quantized::conv.
            stride = packed_args.stride()
            padding = packed_args.padding()
            dilation = packed_args.dilation()
            groups = packed_args.groups()
            args = (fused_weight, fused_bias, stride, padding, dilation, groups)
            packed_args = (
                exir_ops.edge.quantized.conv1d_prepack(*args)
                if conv.target == exir_ops.edge.quantized.conv1d.default
                else exir_ops.edge.quantized.conv2d_prepack(*args)
            )

            # Modify the graph by updating the weight and bias of conv op
            # with the fused weight and bias params, and replacing all the users
            # of batchnorm with the conv op.
            conv_args = list(conv.args)
            conv_args[1] = packed_args
            conv.args = tuple(conv_args)
            bn.replace_all_uses_with(conv)
            graph.erase_node(bn)
            self.counter += 1

        # Note: there is a quantized.conv2d.new operator in the resulting graph
        # that takes a torch.classes.quantized.Conv2dPackedParamsBase as one of
        # its inputs. This prevents us from directly calling graph_module.recompile().
        graph_module._code = graph_module._graph.python_code(root_module="self").src

    def __init__(self):
        super().__init__()
        self.counter = 0

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        self.fuse_quantized_batch_norm_with_conv(graph_module)
        result = super().call(graph_module)
        return result


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class FuseCascadedTransposeOrPermuteOps(ExportPass):
    """
    Fuse a cascaded chain of transpose and permute ops
    """

    transpose_or_permute_target = {
        exir_ops.edge.aten.transpose_copy.int,
        exir_ops.edge.aten.permute_copy.default,
    }

    # Find a chain of transpose or permute ops, and fuse them into a single permute op.
    def fuse_cascaded_transpose_or_permute_ops(
        self, graph_module: torch.fx.GraphModule
    ):
        graph = graph_module.graph
        for node in graph.nodes:
            # We are only interested in permute/transpose ops
            if node.target not in self.transpose_or_permute_target:
                continue
            # Get the cascaded chain of transpose/permute ops starting at node
            cascaded_transpose_or_permute_ops = get_cascaded_ops(
                [node], self.transpose_or_permute_target
            )
            # The chain must have more than 1 node
            if len(cascaded_transpose_or_permute_ops) == 1:
                continue

            out_shape = get_shape(graph_module, node)
            assert out_shape is not None
            out_dims = len(out_shape)
            # This is the trivial dimension order
            dims = list(range(out_dims))
            # Compute the effect of the chain on dims
            for tp in cascaded_transpose_or_permute_ops:
                dims = (
                    get_transposed_dims(tp, dims)
                    if tp.target == exir_ops.edge.aten.transpose_copy.int
                    else get_permuted_dims(tp, dims)
                )

            # If the ops in the chain cancel each other out, the final dims will
            # be the same as the initial order; in that case, the chain was a nop.
            # Otherwise, create a new permute op that encompasses the effect of
            # the chain.
            if dims == list(range(out_dims)):
                cascaded_transpose_or_permute_ops[-1].replace_all_uses_with(
                    node.args[0]
                )
            else:
                with graph.inserting_before(cascaded_transpose_or_permute_ops[-1]):
                    new_permute = graph.call_function(
                        exir_ops.edge.aten.permute_copy.default,
                        args=(node.args[0], dims),
                    )
                cascaded_transpose_or_permute_ops[-1].replace_all_uses_with(new_permute)

            # Now erase the chain
            for tp in reversed(cascaded_transpose_or_permute_ops):
                graph.erase_node(tp)

        graph_module.recompile()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        self.fuse_cascaded_transpose_or_permute_ops(graph_module)
        result = super().call(graph_module)
        return result


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class FuseCascadedViewOps(ExportPass):
    """
    Fuse a cascaded chain of view ops
    """

    # Find a chain of view ops, and fuse them into a single view op.
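    # For example (an illustrative sketch, assuming x holds 12 elements):
    #   y = view_copy(view_copy(x, [3, 4]), [2, 6])
    # is replaced with the single op
    #   y = view_copy(x, [2, 6])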
    def fuse_cascaded_view_ops(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph
        for node in graph.nodes:
            # We are only interested in view ops
            if node.target != exir_ops.edge.aten.view_copy.default:
                continue

            # Get the cascaded chain of view ops starting at node
            cascaded_view_ops = get_cascaded_ops(
                [node], [exir_ops.edge.aten.view_copy.default]
            )
            # The chain must have more than 1 node
            if len(cascaded_view_ops) == 1:
                continue

            last_view_node = cascaded_view_ops[-1]
            with graph.inserting_before(last_view_node):
                new_view = graph.call_function(
                    exir_ops.edge.aten.view_copy.default,
                    args=(node.args[0], last_view_node.args[1]),
                )
            last_view_node.replace_all_uses_with(new_view)

            # Now erase the chain
            for v in reversed(cascaded_view_ops):
                graph.erase_node(v)

        graph_module.recompile()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        self.fuse_cascaded_view_ops(graph_module)
        dead_code_elimination_pass(graph_module)
        result = super().call(graph_module)
        return result


class FuseOpPairsAcrossBranchesPass(ExportPass):
    def check_ok_to_fuse(
        self,
        producer: torch.fx.Node,
        consumers: list[torch.fx.Node],
    ) -> bool:
        # Always ok to replace / remove.
        return True

    def can_fuse_for_chain(
        self,
        producer: torch.fx.Node,
        consumer: torch.fx.Node,
        consumer_op_packets: set[EdgeOpOverloadPacket],
    ) -> bool:
        """
        Returns true if producer and consumer can be fused for a single chain,
        i.e. (-> producer -> ops -> consumer ->) becomes (-> ops -> fused_op ->).
        """
        if (
            isinstance(consumer.target, EdgeOpOverload)
            and get_edge_overload_packet(consumer.target) in consumer_op_packets
        ):
            return True
        return False

    def get_fuse_candidates(
        self,
        producer: torch.fx.Node,
        consumer_op_packets: set[EdgeOpOverloadPacket],
        bypass_ops: set[EdgeOpOverload],
    ) -> list[torch.fx.Node]:
        # Start by iterating over all the users of this node, and check
        # if they have their target in consumer_op_packets.
        users = deque(producer.users.keys())
        # This holds the list of the user ops that directly (or transitively
        # via view/slice) consume this producer op, and hence can be removed.
        removal_candidates = []
        while users:
            user = users.popleft()

            # If the user is a bypass op, we bypass it, and examine
            # its users instead for consumer_op_packets.
            if user.target in bypass_ops:
                users.extend(list(user.users.keys()))
            elif self.can_fuse_for_chain(producer, user, consumer_op_packets):
                removal_candidates.append(user)
            else:
                removal_candidates.clear()
                break
        return removal_candidates

    def find_and_fuse(
        self,
        graph_module: torch.fx.GraphModule,
        producer_op_packets: set[EdgeOpOverloadPacket],
        consumer_op_packets: set[EdgeOpOverloadPacket],
        bypass_ops: set[EdgeOpOverload],
    ) -> None:
        for node in graph_module.graph.nodes:
            # We are only interested in ops that have their overload target in
            # producer_op_packets.
            if not (
                isinstance(node.target, EdgeOpOverload)
                and get_edge_overload_packet(node.target) in producer_op_packets
            ):
                continue

            removal_candidates = self.get_fuse_candidates(
                node, consumer_op_packets, bypass_ops
            )

            if len(removal_candidates) == 0:
                # No candidates found.
                continue

            if not self.check_ok_to_fuse(node, removal_candidates):
                # Not ok to remove quant-dequant pairs or replace with requantize.
                continue

            self.fuse(node, removal_candidates, graph_module)

        graph_module.recompile()

    def get_fused_node(
        self,
        producer: torch.fx.Node,
        consumer: torch.fx.Node,
        graph_module: torch.fx.GraphModule,
    ) -> torch.fx.Node:
        return consumer

    def fuse(
        self,
        node: torch.fx.Node,
        removal_candidates: list[torch.fx.Node],
        graph_module: torch.fx.GraphModule,
    ) -> None:
        # Replace all the uses of the producer op with its input.
        node.replace_all_uses_with(cast(torch.fx.Node, node.args[0]))
        graph_module.graph.erase_node(node)

        # Iterate over all the removal candidates (quantize op users) and generate replacements.
        for rnode in removal_candidates:
            rnode.replace_all_uses_with(self.get_fused_node(node, rnode, graph_module))
            graph_module.graph.erase_node(rnode)


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class FuseQuantDequantToRequantizePass(FuseOpPairsAcrossBranchesPass):
    """
    Fuse dequantize-quantize op pairs to a single requantize op.
    For the special case where quant params match, this will remove
    both dequant and quant ops.
    """

    # A list of ops that can be bypassed when looking for a
    # dequantize->quantize chain
    bypass_ops: set[EdgeOpOverload] = {
        exir_ops.edge.aten.slice_copy.Tensor,
        exir_ops.edge.aten.view_copy.default,
        exir_ops.edge.aten.clone.default,
        exir_ops.edge.aten.transpose_copy.int,
        exir_ops.edge.aten.permute_copy.default,
    }

    quantize_op_packets: set[EdgeOpOverloadPacket] = {
        exir_ops.edge.cadence.quantize_per_tensor,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor,
    }
    dequantize_op_packets: set[EdgeOpOverloadPacket] = {
        exir_ops.edge.cadence.dequantize_per_tensor,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor,
    }

    def __init__(
        self, allow_requantize: bool = True, force_quant_dequant_fusion: bool = False
    ) -> None:
        super().__init__()
        self.allow_requantize: bool = allow_requantize
        self.force_quant_dequant_fusion: bool = force_quant_dequant_fusion

    def _pkg_name_match(self, node1: torch.fx.Node, node2: torch.fx.Node) -> bool:
        # pyre-ignore[16]: Item `typing.Callable` has no attribute `_op`
        return node1.target._op.namespace == node2.target._op.namespace

    def can_fuse_for_chain(
        self,
        producer: torch.fx.Node,
        consumer: torch.fx.Node,
        consumer_op_packets: set[EdgeOpOverloadPacket],
    ) -> bool:
        return super().can_fuse_for_chain(
            producer, consumer, consumer_op_packets
        ) and self._pkg_name_match(producer, consumer)

    def _create_requantize_node(
        self,
        in_tensor: torch.fx.Node,
        in_scale: float,
        in_zero_point: int,
        out_scale: float,
        out_zero_point: int,
        out_dtype: torch.dtype,
        graph: torch.fx.Graph,
    ) -> torch.fx.Node:
        in_scale_tensor = graph.call_function(
            exir_ops.edge.aten.full.default, args=((1,), in_scale)
        )
        in_zero_point_tensor = graph.call_function(
            exir_ops.edge.aten.full.default,
            args=((1,), in_zero_point),
            kwargs={"dtype": torch.int32},
        )
        out_scale_tensor = graph.call_function(
            exir_ops.edge.aten.full.default, args=((1,), out_scale)
        )
        out_zero_point_tensor = graph.call_function(
            exir_ops.edge.aten.full.default,
            args=((1,), out_zero_point),
            kwargs={"dtype": torch.int32},
        )
        # cadence::requantize(Tensor input, Tensor in_scale, Tensor in_zero_point,
        #     Tensor out_scale, Tensor out_zero_point, ScalarType out_dtype) -> Tensor Y
        # TODO(hardiksharma): Add support for per-tensor requantize.
        return graph.call_function(
            exir_ops.edge.cadence.requantize.default,
            args=(
                in_tensor,
                in_scale_tensor,
                in_zero_point_tensor,
                out_scale_tensor,
                out_zero_point_tensor,
                out_dtype,
            ),
        )

    def _quant_params_match(self, node1: torch.fx.Node, node2: torch.fx.Node) -> bool:
        return node1.args[1:] == node2.args[1:]

    def check_ok_to_fuse(
        self,
        producer: torch.fx.Node,
        consumers: list[torch.fx.Node],
    ) -> bool:
        """Check if all node-user pairs are nops or are ok to replace with requant."""
        for rnode in consumers:
            if self.allow_requantize or self._quant_params_match(producer, rnode):
                # Cannot remove a quant-dequant pair if the quant params don't match
                # and requantize is not allowed.
                continue
            return False
        return True

    def get_fused_node(
        self,
        producer: torch.fx.Node,
        consumer: torch.fx.Node,
        graph_module: torch.fx.GraphModule,
    ) -> torch.fx.Node:
        in_scale, in_zero_point = producer.args[1:3]
        in_tensor, out_scale, out_zero_point, _, _, out_dtype = consumer.args
        if in_scale == out_scale and in_zero_point == out_zero_point:
            # If the quant params match, we can remove both dequantize-quantize ops.
            return cast(torch.fx.Node, consumer.args[0])

        assert (
            self.allow_requantize
        ), f"Found {producer=} {in_scale=} {in_zero_point=} | {consumer=} {out_scale=} {out_zero_point=}"

        with graph_module.graph.inserting_before(consumer):
            requantize_node = self._create_requantize_node(
                in_tensor=cast(torch.fx.Node, consumer.args[0]),
                in_scale=cast(float, in_scale),
                in_zero_point=cast(int, in_zero_point),
                out_scale=cast(float, out_scale),
                out_zero_point=cast(int, out_zero_point),
                out_dtype=cast(torch.dtype, out_dtype),
                graph=graph_module.graph,
            )
            return requantize_node

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        # Remove any dequantize op that has only quantize ops as its users.
        self.find_and_fuse(
            graph_module,
            producer_op_packets=self.dequantize_op_packets,
            consumer_op_packets=self.quantize_op_packets,
            bypass_ops=self.bypass_ops,
        )
        # Remove any quantize op that has only dequantize ops as its users.
        self.find_and_fuse(
            graph_module,
            producer_op_packets=self.quantize_op_packets,
            consumer_op_packets=self.dequantize_op_packets,
            # Do not requantize for quantize-dequantize pairs as this is not guaranteed
            # to be better for performance/memory.
            # Only fuse if all users of quant are dequant.
            bypass_ops=(
                self.bypass_ops
                if self.force_quant_dequant_fusion
                else {exir_ops.edge.aten.view_copy.default}
            ),
        )
        result = super().call(graph_module)
        return result


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class FuseMulIntoDequantPass(ExportPass):
    """
    Looks for the pattern where aten.mul is multiplying the outputs of dequantize
    and aten.full. If found, updates the dequant scale to reflect the multiplication
    and removes the full and mul nodes.
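
    For example (an illustrative sketch; the quant params are made up):
        x = dequantize_per_tensor(q, scale=0.05, zero_point=0, ...)
        c = aten.full(shape, 2.0)
        y = aten.mul(x, c)
    becomes
        y = dequantize_per_tensor(q, scale=0.1, zero_point=0, ...)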
    """

    def attempt_fusion(
        self, graph_module: torch.fx.GraphModule, node: torch.fx.Node
    ) -> None:
        if node.target != exir_ops.edge.aten.mul.Tensor:
            return

        # Ensure that one of the args to mul is dequantize and the other is aten.full
        dequant_nodes = [
            arg
            for arg in node.args
            if isinstance(arg, torch.fx.Node)
            and isinstance(arg.target, EdgeOpOverload)
            and get_edge_overload_packet(arg.target)
            == exir_ops.edge.quantized_decomposed.dequantize_per_tensor
        ]
        multiplier_nodes = [
            arg
            for arg in node.args
            if isinstance(arg, torch.fx.Node)
            and arg.target == exir_ops.edge.aten.full.default
        ]

        if len(dequant_nodes) != 1 or len(multiplier_nodes) != 1:
            return

        deq_node = dequant_nodes[0]
        mplier_node = multiplier_nodes[0]

        # Ensure that dequant and full don't have any other users
        if len(deq_node.users) > 1 or len(mplier_node.users) > 1:
            return

        new_deq_args = list(deq_node.args)
        assert isinstance(deq_node.args[1], Number)
        assert isinstance(mplier_node.args[1], Number)
        # pyre-ignore[58]: Unsupported operand *
        new_deq_args[1] = deq_node.args[1] * mplier_node.args[1]

        logging.debug(
            f"Fused {node} and {mplier_node} into {deq_node}. Updated scale from {deq_node.args[1]} to {new_deq_args[1]}"
        )

        node.replace_all_uses_with(deq_node)
        deq_node.args = tuple(new_deq_args)

        graph_module.graph.erase_node(node)
        graph_module.graph.erase_node(mplier_node)
        graph_module.recompile()

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        for node in graph_module.graph.nodes:
            self.attempt_fusion(graph_module, node)
        result = super().call(graph_module)
        return result


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class FuseTransposeOpPairsPass(FuseOpPairsAcrossBranchesPass):
    """
    Fuse pairs of transpose ops, possibly separated by quantize/dequantize ops,
    into a single view op when the combined effect of the two transposes is a
    no-op on the non-unit dimensions of the tensor.
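
    For example (an illustrative sketch):
        t1 = transpose_copy(x, 0, 1)
        q = quantize_per_tensor(t1, ...)
        t2 = transpose_copy(q, 0, 1)
    becomes
        q = quantize_per_tensor(x, ...)
        out = view_copy(q, <output shape of t2>)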
    """

    # A list of ops that can be bypassed when looking for a
    # transpose->transpose chain
    bypass_ops: set[EdgeOpOverload] = {
        exir_ops.edge.cadence.quantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
        exir_ops.edge.cadence.dequantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
    }

    def can_fuse_for_chain(
        self,
        producer: torch.fx.Node,
        consumer: torch.fx.Node,
        consumer_op_packets: set[EdgeOpOverloadPacket],
    ) -> bool:
        if not super().can_fuse_for_chain(producer, consumer, consumer_op_packets):
            return False

        def get_dims(node: torch.fx.Node) -> tuple[int, int]:
            def canonicalize(dim: int) -> int:
                if dim < 0:
                    dim += len(node.meta["val"].shape)
                return dim

            return tuple(canonicalize(cast(int, d)) for d in node.args[1:3])

        def is_equivalent(
            shape: Sequence[int],
            transpose0: tuple[int, int],
            transpose1: tuple[int, int],
        ) -> bool:
            def permute_order(
                order: Sequence[int], dims: tuple[int, int]
            ) -> Sequence[int]:
                new_order = list(order)
                new_order[dims[0]], new_order[dims[1]] = (
                    new_order[dims[1]],
                    new_order[dims[0]],
                )
                return new_order

            order = permute_order(range(len(shape)), transpose0)
            order = permute_order(order, transpose1)

            non_unit_dims = [dim for dim in range(len(shape)) if shape[dim] != 1]
            non_unit_dims_permuted = [dim for dim in order if shape[dim] != 1]

            return non_unit_dims == non_unit_dims_permuted

        return is_equivalent(
            cast(torch.fx.Node, producer.args[0]).meta["val"].shape,
            get_dims(producer),
            get_dims(consumer),
        )

    def get_fused_node(
        self,
        producer: torch.fx.Node,
        consumer: torch.fx.Node,
        graph_module: torch.fx.GraphModule,
    ) -> torch.fx.Node:
        output_shape = consumer.meta["val"].shape
        with graph_module.graph.inserting_after(consumer):
            view = graph_module.graph.call_function(
                exir_ops.edge.aten.view_copy.default,
                (consumer.args[0], output_shape),
                {},
            )
        return view

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        # Fuse transpose op pairs that are only separated by the bypass ops above.
        self.find_and_fuse(
            graph_module,
            producer_op_packets={exir_ops.edge.aten.transpose_copy},
            consumer_op_packets={exir_ops.edge.aten.transpose_copy},
            bypass_ops=self.bypass_ops,
        )
        result = super().call(graph_module)
        return result


@register_cadence_pass(CadencePassAttribute(opt_level=1))
class FuseFullThenReshapePass(ExportPass):
    """
    A pass that fuses a chain of full and reshape-like operations into a single
    full operation.
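
    For example (an illustrative sketch):
        x = aten.full((1, 4), 0.0)
        y = aten.view_copy(x, (2, 2))
    becomes
        y = aten.full((2, 2), 0.0)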
    """

    fusion_candidates: set[EdgeOpOverload] = {
        exir_ops.edge.aten.transpose_copy.int,
        exir_ops.edge.aten.permute_copy.default,
        exir_ops.edge.aten.view_copy.default,
    }

    def call_operator(
        self,
        op,
        args: tuple[Argument, ...],
        kwargs: dict[str, Argument],
        meta: NodeMetadata,
    ) -> ProxyValue:
        if op not in self.fusion_candidates:
            return super().call_operator(op, args, kwargs, meta)

        full_node = cast(ProxyValue, args[0]).node
        if not (
            full_node.op == "call_function"
            and full_node.target == exir_ops.edge.aten.full.default
        ):
            # Only fuse the full -> fusion_candidates pattern.
            return super().call_operator(op, args, kwargs, meta)

        fill_value = full_node.args[1]
        return super().call_operator(
            exir_ops.edge.aten.full.default,
            (
                meta["val"].shape,
                fill_value,
            ),
            {"dtype": meta["val"].dtype},
            meta,
        )

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        graph_module = super().call(graph_module).graph_module
        graph_module.graph.eliminate_dead_code()
        return PassResult(graph_module, True)


class CadenceFuseOpsInGraph:
    passes = [
        FuseMMWithAdd,
        FuseBatchNormWithConv,
        FuseQuantizedBatchNormWithConv,
        FuseCascadedTransposeOrPermuteOps,
        FuseCascadedViewOps,
        FuseQuantDequantToRequantizePass,
        FuseMulIntoDequantPass,
        FuseFullThenReshapePass,
        FuseTransposeOpPairsPass,
    ]
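

# One way the list above might be consumed (an illustrative sketch, not taken
# from the production pipeline; `graph_module` is a placeholder for a
# torch.fx.GraphModule produced by the Cadence AOT flow):
#
#     for fuse_pass in CadenceFuseOpsInGraph.passes:
#         pass_result = fuse_pass()(graph_module)
#         assert pass_result is not None
#         graph_module = pass_result.graph_module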