# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.


import unittest

import executorch.backends.cadence.aot.ops_registrations  # noqa
import torch
from executorch.backends.cadence.aot import compiler
from executorch.backends.cadence.aot.compiler import export_to_edge, quantize_pt2
from executorch.backends.cadence.aot.fuse_ops import (
    FuseFullThenReshapePass,
    FuseMulIntoDequantPass,
    FuseQuantDequantToRequantizePass,
    FuseTransposeOpPairsPass,
)
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
from executorch.backends.cadence.aot.pass_utils import count_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from torch import nn


class TestFusionPassesBase(unittest.TestCase):
    def check_op_counts(
        self,
        graph_module: torch.fx.GraphModule,
        expected_op_counts: dict[EdgeOpOverload, int],
    ) -> None:
        for op, count in expected_op_counts.items():
            self.assertEqual(count_node(graph_module, op), count)


class TestFusionPasses(TestFusionPassesBase):
    def test_addmm_fusion(self):
        class AddmmFeasible1(torch.nn.Module):
            def forward(self, x, y, z):
                t1 = torch.mm(x, y)
                return torch.add(t1, z)

        x = torch.randn(3, 5)
        y = torch.randn(5, 6)
        z = torch.randn(6)

        graph_module = (
            compiler.export_to_cadence(AddmmFeasible1(), (x, y, z))
            .exported_program()
            .graph_module
        )
        graph_module.graph.eliminate_dead_code()

        # Assert that mm and add were fused to addmm
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1)
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0)
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0)

        class AddmmFeasible2(torch.nn.Module):
            def forward(self, x, y, z):
                t1 = y.view((8, 6))
                t2 = torch.mm(x, t1)
                t3 = t2.view((2, 2, 6))
                return torch.add(t3, z)

        x = torch.randn(4, 8)
        y = torch.randn(2, 4, 6)
        z = torch.randn(6)

        graph_module = (
            compiler.export_to_cadence(AddmmFeasible2(), (x, y, z))
            .exported_program()
            .graph_module
        )
        graph_module.graph.eliminate_dead_code()
        # Assert that mm and add were fused to addmm
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1)
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0)
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0)

        # Bias is a singleton value, broadcastable to output of mm
        class AddmmFeasible3(torch.nn.Module):
            def forward(self, x, y):
                t1 = torch.mm(x, y)
                return torch.add(t1, torch.ones(1))

        x = torch.randn(3, 5)
        y = torch.randn(5, 6)

        graph_module = (
            compiler.export_to_cadence(AddmmFeasible3(), (x, y))
            .exported_program()
            .graph_module
        )
        graph_module.graph.eliminate_dead_code()
        # Assert that mm and add were fused to addmm
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1)
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0)
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0)
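
        # The fusion precondition mirrors torch.addmm semantics:
        # addmm(bias, x, y) computes bias + x @ y, so the pass can only fire
        # when the bias broadcasts to the mm output (and, per the infeasible
        # cases below, when the mm/add pair has no extra users). Illustrative
        # sanity check (added for clarity; not one of the original fusion
        # assertions):
        self.assertTrue(
            torch.allclose(torch.addmm(torch.ones(1), x, y), torch.mm(x, y) + 1.0)
        )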

        # Bias is not broadcastable to output of mm
        class AddmmInfeasible1(torch.nn.Module):
            def forward(self, x, y, z):
                t1 = y.view((8, 6))
                t2 = torch.mm(x, t1)
                t3 = t2.view((2, 2, 6))
                return torch.add(t3, z)

        x = torch.randn(4, 8)
        y = torch.randn(2, 4, 6)
        z = torch.randn(2, 2, 1)

        graph_module = (
            compiler.export_to_cadence(AddmmInfeasible1(), (x, y, z))
            .exported_program()
            .graph_module
        )
        graph_module.graph.eliminate_dead_code()
        # Assert that mm and add were not fused to addmm, since z cannot be
        # broadcast to the output of mm.
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 1)

        # The add consuming the output of mm has more than one user.
        class AddmmInfeasible2(torch.nn.Module):
            def forward(self, x, y, z):
                t1 = torch.mm(x, y)
                t2 = torch.add(t1, z)
                t3 = torch.add(t2, z)
                return torch.add(t2, t3)

        x = torch.randn(3, 5)
        y = torch.randn(5, 6)
        z = torch.randn(6)

        graph_module = (
            compiler.export_to_cadence(AddmmInfeasible2(), (x, y, z))
            .exported_program()
            .graph_module
        )
        graph_module.graph.eliminate_dead_code()
        # Assert that mm and add were not fused to addmm, since the add has
        # multiple users.
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 3)

    # TODO(matthiascremon): enable this pass with the new flow
    @torch.no_grad()
    @unittest.expectedFailure
    def test_legacy_conv_bn_fusion(self):
        class ModelConvBN(torch.nn.Module):
            def __init__(self, in_features: int, out_features: int, kernel_size: int):
                super().__init__()
                self.conv1d = nn.Conv1d(in_features, out_features, kernel_size)
                self.bn = nn.BatchNorm1d(out_features)

            def forward(self, x):
                y = self.conv1d(x)
                return self.bn(y)

        model = ModelConvBN(64, 1, 2)
        x = torch.randn(1, 64, 4)

        graph_module = (
            compiler.export_to_executorch(model.eval(), (x,))
            .exported_program()
            .graph_module
        )
        # Assert that after running the fusion passes, batchnorm was fused with conv1d
        self.assertEqual(
            count_node(graph_module, torch.ops.aten.linear.out)
            + count_node(graph_module, torch.ops.cadence.convolution.out),
            1,
        )
        self.assertEqual(
            count_node(
                graph_module, torch.ops.aten._native_batch_norm_legit_no_training.out
            ),
            0,
        )

    def test_permute_transpose_fusion(self):
        class PermuteTranspose(torch.nn.Module):
            def forward(self, x):
                y = x.permute((0, 2, 4, 1, 3))
                return y.transpose(0, 1)

        x = torch.randn(3, 1, 3, 1, 4)
        graph_module = (
            compiler.export_to_cadence(PermuteTranspose(), (x,))
            .exported_program()
            .graph_module
        )
        graph_module.graph.eliminate_dead_code()
        # Assert that the permute and transpose ops were fused into a single permute
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 1
        )
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.transpose_copy.int), 0
        )

    def test_view_fusion(self):
        class ViewFusion(torch.nn.Module):
            def forward(self, x):
                x = x.view([1, 8, 15])
                x = x.view([1, 1, 120])
                return x.view([1, 12, 10])

        x = torch.randn(8, 5, 3)
        graph_module = (
            compiler.export_to_cadence(ViewFusion(), (x,))
            .exported_program()
            .graph_module
        )
        graph_module.graph.eliminate_dead_code()
        # Assert that only one view op remains
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.view_copy.default), 1
        )
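
    # The next several tests exercise FuseQuantDequantToRequantizePass. As the
    # op-count assertions below spell out, the pass rewrites quantize/dequantize
    # chains into a single cadence.requantize op, but only when the rewrite is
    # provably value-preserving (or when fusion is explicitly forced).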
    def test_force_quant_dequant_fusion(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = torch.ops.quantized_decomposed.quantize_per_tensor(
                    x, 1.2, 3, 0, 127, torch.int8
                )
                x = torch.permute(x, [2, 0, 1, 3])
                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                    x, 4.5, 6, 0, 127, torch.int8
                )
                return x

        inputs = torch.randn(2, 12, 1, 6)
        model = M()
        graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module

        graph_module = FuseQuantDequantToRequantizePass(
            force_quant_dequant_fusion=True
        )(graph_module).graph_module
        self.check_op_counts(
            graph_module,
            expected_op_counts={
                # Verify that with force_quant_dequant_fusion=True, even
                # quantize -> permute -> dequantize is replaced with requantize.
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
                exir_ops.edge.cadence.requantize.default: 1,
            },
        )

    def test_no_replace_quant_permute_dequant_with_requantize(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = torch.ops.quantized_decomposed.quantize_per_tensor(
                    x, 1.2, 3, 0, 127, torch.int8
                )
                x = torch.permute(x, [2, 0, 1, 3])
                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                    x, 4.5, 6, 0, 127, torch.int8
                )
                return x

        inputs = torch.randn(2, 12, 1, 6)
        model = M()
        graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module

        graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
        self.check_op_counts(
            graph_module,
            expected_op_counts={
                # Verify that no quant/dequant pair was replaced with requantize:
                # quantize -> permute -> dequantize should not be fused by default.
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1,
                exir_ops.edge.cadence.requantize.default: 0,
            },
        )
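
    # The contrast with the test above: by default the pass does fuse across a
    # view (a shape-only reshape that preserves element order), but not across
    # a permute, which reorders elements; the force_quant_dequant_fusion flag
    # overrides that caution.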
    def test_replace_quant_view_dequant_with_requantize(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = torch.ops.quantized_decomposed.quantize_per_tensor(
                    x, 1.2, 3, 0, 127, torch.int8
                )
                x = x.view(-1)
                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                    x, 4.5, 6, 0, 127, torch.int8
                )
                return x

        inputs = torch.randn(2, 12, 1, 6)
        model = M()
        graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
        graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module

        self.check_op_counts(
            graph_module,
            expected_op_counts={
                # Verify that quantize -> view -> dequantize was replaced with
                # view -> requantize.
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
                exir_ops.edge.cadence.requantize.default: 1,
            },
        )

    def test_replace_dequant_quant_with_requantize(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                    x, 1.2, 3, 0, 127, torch.int8
                )
                x = torch.ops.quantized_decomposed.quantize_per_tensor(
                    x, 4.5, 6, 0, 127, torch.int8
                )
                return x

        inputs = torch.randn(2, 12, 1, 6).to(torch.int8)
        model = M()
        graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
        graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module

        self.check_op_counts(
            graph_module,
            expected_op_counts={
                # Verify that dequant -> quant was replaced with requantize.
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
                exir_ops.edge.cadence.requantize.default: 1,
            },
        )
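
    # Conceptually, requantize folds a dequantize followed by a quantize into
    # one integer rescale, e.g.
    #   requantize(q) = round((q - in_zp) * in_scale / out_scale) + out_zp,
    # avoiding the intermediate float tensor. (That is the standard
    # requantization identity; these tests only assert op counts, not the
    # exact rounding behavior of cadence.requantize.)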
    def test_replace_dequant_permute_quant_with_requantize(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                    x, 1.2, 3, 0, 127, torch.int8
                )
                x = torch.permute(x, [2, 0, 1, 3])
                x = torch.ops.quantized_decomposed.quantize_per_tensor(
                    x, 4.5, 6, 0, 127, torch.int8
                )
                return x

        inputs = torch.randn(2, 12, 1, 6).to(torch.int8)
        model = M()
        graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
        graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module

        self.check_op_counts(
            graph_module,
            expected_op_counts={
                # Verify that dequant -> permute -> quant was replaced with
                # permute -> requantize.
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
                exir_ops.edge.cadence.requantize.default: 1,
            },
        )

    def test_remove_nop_dequant_quant(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lin1 = torch.nn.Linear(6, 12, bias=False)
                self.lin2 = torch.nn.Linear(12, 24, bias=False)

            def forward(self, x):
                x = self.lin1(x)
                # A redundant dequant/quant pair will be created around this permute
                x = torch.permute(x, [0, 2, 1, 3])
                x = self.lin2(x)
                return x

        inputs = torch.randn(2, 12, 1, 6)
        model = M()
        quantized_model = quantize_pt2(model, (inputs,))
        graph_module = (
            export_to_edge(quantized_model, (inputs,)).exported_program().graph_module
        )
        graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
        self.check_op_counts(
            graph_module,
            expected_op_counts={
                # Verify that one dequant/quant pair was removed:
                # expect exactly 1 quantize op (for the input)
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
                # and exactly 1 dequant op at the end (output of the second linear)
                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1,
            },
        )

    def test_fuse_mul_into_dequant(self):
        class M(torch.nn.Module):
            def forward(self, x):
                x0 = torch.ops.quantized_decomposed.dequantize_per_tensor(
                    x, 1.5, 0, 0, 255, torch.uint8
                )
                x1 = torch.full([4, 32], 3, dtype=torch.float32)
                x2 = x0 * x1
                return x2

        inputs = (torch.randint(0, 255, [4, 32], dtype=torch.uint8),)
        graph_module = export_to_edge(M(), inputs).exported_program().graph_module
        graph_module = FuseMulIntoDequantPass()(graph_module).graph_module

        # Verify that the mul and full ops were removed
        self.check_op_counts(
            graph_module,
            expected_op_counts={
                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1,
                exir_ops.edge.aten.full.default: 0,
                exir_ops.edge.aten.mul.Tensor: 0,
            },
        )

        # Verify that the dequant scale value was updated correctly
        for node in graph_module.graph.nodes:
            if (
                node.target
                == exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
            ):
                deq_scale = node.args[1]
                self.assertEqual(deq_scale, 4.5)
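
    # The scale folding above follows directly from the dequantize formula:
    # dequant(q) * c = (q - zero_point) * scale * c, i.e. a dequantize with
    # scale' = scale * c. Here scale = 1.5 and the full op supplies c = 3,
    # hence the asserted fused scale of 1.5 * 3 = 4.5.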
    def test_fuse_then_transpose_pass(self):
        # Create a graph with full -> transpose -> permute -> view.
        builder = GraphBuilder()
        full_node = builder.call_operator(
            op=exir_ops.edge.aten.full.default, args=((2, 3), 1)
        )
        transpose_node = builder.call_operator(
            op=exir_ops.edge.aten.transpose_copy.int,
            args=(full_node, 0, 1),
        )
        permute_node = builder.call_operator(
            op=exir_ops.edge.aten.permute_copy.default,
            args=(transpose_node, (1, 0)),
        )
        view_node = builder.call_operator(
            op=exir_ops.edge.aten.view_copy.default,
            args=(permute_node, (1, 6, 1)),
        )
        builder.output(view_node)
        gm = builder.get_graph_module()
        self.check_op_counts(
            gm,
            expected_op_counts={
                exir_ops.edge.aten.full.default: 1,
                exir_ops.edge.aten.transpose_copy.int: 1,
                exir_ops.edge.aten.permute_copy.default: 1,
                exir_ops.edge.aten.view_copy.default: 1,
            },
        )

        # Check that the pass fuses the full with all the other ops
        # (transpose, permute, view).
        gm_after_pass = FuseFullThenReshapePass()(gm).graph_module
        self.check_op_counts(
            gm_after_pass,
            expected_op_counts={
                exir_ops.edge.aten.full.default: 1,
                exir_ops.edge.aten.transpose_copy.int: 0,
                exir_ops.edge.aten.permute_copy.default: 0,
                exir_ops.edge.aten.view_copy.default: 0,
            },
        )


class TestFuseTransposeOpPairsPass(TestFusionPassesBase):
    def test_fuse_transpose_pairs(self):
        # Create a graph with transpose -> quant -> transpose.
        builder = GraphBuilder()
        x = builder.placeholder("x", torch.randn(2, 3))
        transpose_node = builder.call_operator(
            op=exir_ops.edge.aten.transpose_copy.int,
            args=(x, 0, 1),
        )
        quant_node = builder.call_operator(
            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            args=(transpose_node, 1.2, 3, 0, 127, torch.int8),
        )
        transpose_node = builder.call_operator(
            op=exir_ops.edge.aten.transpose_copy.int,
            args=(quant_node, 0, 1),
        )
        builder.output(transpose_node)
        gm = builder.get_graph_module()
        self.check_op_counts(
            gm,
            expected_op_counts={
                exir_ops.edge.aten.transpose_copy.int: 2,
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
            },
        )

        # Check that the pass fuses (cancels) the two transpose ops.
        gm_after_pass = FuseTransposeOpPairsPass()(gm).graph_module
        self.check_op_counts(
            gm_after_pass,
            expected_op_counts={
                exir_ops.edge.aten.transpose_copy.int: 0,
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
            },
        )

    def test_no_fusion_for_transpose_pairs(self):
        # Create a graph with transpose -> quant -> transpose, where the two
        # transposes swap different pairs of dimensions.
        builder = GraphBuilder()
        x = builder.placeholder("x", torch.randn(2, 3, 4))
        transpose_node = builder.call_operator(
            op=exir_ops.edge.aten.transpose_copy.int,
            args=(x, 0, 1),
        )
        quant_node = builder.call_operator(
            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            args=(transpose_node, 1.2, 3, 0, 127, torch.int8),
        )
        transpose_node = builder.call_operator(
            op=exir_ops.edge.aten.transpose_copy.int,
            args=(quant_node, 1, 2),
        )
        builder.output(transpose_node)
        gm = builder.get_graph_module()
        self.check_op_counts(
            gm,
            expected_op_counts={
                exir_ops.edge.aten.transpose_copy.int: 2,
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
            },
        )

        # No fusion: the transposes do not cancel each other.
        gm_after_pass = FuseTransposeOpPairsPass()(gm).graph_module
        self.check_op_counts(
            gm_after_pass,
            expected_op_counts={
                exir_ops.edge.aten.transpose_copy.int: 2,
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
            },
        )
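
    # Transposing the same pair of dims twice is the identity, which is why the
    # matching transpose(0, 1) pairs fuse away while the mismatched
    # transpose(0, 1) / transpose(1, 2) pair above is left untouched. The next
    # test checks that cancellation also works when the first transpose fans
    # out to several quant -> transpose branches.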
    def test_fusion_for_forked_transposes(self):
        # Create a graph with a transpose fanning out to multiple
        # quant -> transpose branches.
        builder = GraphBuilder()
        x = builder.placeholder("x", torch.randn(2, 3, 4, dtype=torch.float32))
        transpose_node = builder.call_operator(
            op=exir_ops.edge.aten.transpose_copy.int,
            args=(x, 0, 1),
        )
        num_forks = 3
        outputs = []
        for _ in range(num_forks):
            quant_node = builder.call_operator(
                op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
                args=(transpose_node, 1.2, 3, 0, 127, torch.int8),
            )
            outputs.append(
                builder.call_operator(
                    op=exir_ops.edge.aten.transpose_copy.int,
                    args=(quant_node, 0, 1),
                )
            )
        builder.output(outputs)
        gm = builder.get_graph_module()
        self.check_op_counts(
            gm,
            expected_op_counts={
                exir_ops.edge.aten.transpose_copy.int: num_forks + 1,
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: num_forks,
            },
        )

        # Fuse all the transpose ops.
        gm_after_pass = FuseTransposeOpPairsPass()(gm).graph_module
        self.check_op_counts(
            gm_after_pass,
            expected_op_counts={
                exir_ops.edge.aten.transpose_copy.int: 0,
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: num_forks,
            },
        )
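

if __name__ == "__main__":
    # Allow running this test file directly. This guard is a common unittest
    # convention added here for convenience; the original file may rely solely
    # on the project's test runner.
    unittest.main()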