# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

import unittest
from typing import Any, Callable, cast, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn.functional as F
from executorch.backends.cadence.aot import compiler
from executorch.backends.cadence.aot.compiler import export_to_edge, quantize_pt2
from executorch.backends.cadence.aot.graph_builder import single_op_builder
from executorch.backends.cadence.aot.pass_utils import count_node
from executorch.backends.cadence.aot.replace_ops import (
    ForceChannelLastForConvPass,
    MakeSliceAndCatDimOutermostPass,
    ReplaceAddMMWithLinearPass,
    ReplaceAtenConvolutionWithJarvisConvolutionPass,
    ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass,
    ReplaceConstantPadNdWithSlicePass,
    ReplaceConvolutionOptionalArgsWithConcreteArgsPass,
    ReplaceConvWithIm2RowAndLinear,
    ReplaceFunctionallyEquivalentOpTargets,
    ReplaceIm2RowWithViewPass,
    ReplaceLinearWithFullyConnectedOpPass,
    ReplaceMMWithAddMMPass,
    ReplaceNopTransposeOrPermuteWithViewPass,
    ReplacePadWithCatPass,
    ReplacePermuteWithTransposePass,
    ReplaceRepeatWithCatPass,
    ReplaceScalarTensorWithFullPass,
    ReplaceScalarWithTensorArgPass,
    ReplaceSelectWithViewOpPass,
    ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
    ReplaceSqueezeAndUnsqueezeWithViewPass,
    ReplaceTCopyWithTransposePass,
    ReplaceTransposedConvWithLinearPass,
    ReplaceTrivialConvWithLinear,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass
from executorch.exir.passes import dead_code_elimination_pass

from parameterized.parameterized import parameterized
from torch._ops import OpOverload
from torch.fx.passes.infra.pass_base import PassResult


class TestReplaceOpsPasses(unittest.TestCase):
    def assertTargetCountEqual(
        self,
        graph_module: torch.fx.GraphModule,
        target: Union[Callable[..., Any], str],
        expected_count: int,
    ):
        """Helper function to check the number of nodes with a given target."""
        actual_count = count_node(graph_module, target)
        self.assertEqual(
            actual_count,
            expected_count,
            f"{target} count mismatch for graph {graph_module}",
        )

    def assertTargetCountsEqual(
        self,
        graph_module: torch.fx.GraphModule,
        targets_and_counts: List[Tuple[Union[Callable[..., Any], str], int]],
    ):
        """Helper function to check the node count for each given (target, count) pair."""
        for target, expected_count in targets_and_counts:
            self.assertTargetCountEqual(graph_module, target, expected_count)

    @parameterized.expand(
        [
            [(3, 5), (0, 0)],
            [
                (20, 1, 80),
                (0, 0),
            ],
        ]
    )
    @torch.no_grad()
    def test_replace_constant_pad_nd_with_slice(
        self, shape: Tuple[int], padding: Tuple[int]
    ):
        # F.pad is converted to aten::constant_pad_nd after functionalization & decomposition.
        class Padding(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.padding = padding

            def forward(self, x: torch.Tensor):
                return F.pad(x, self.padding)

        model = Padding()
        x = torch.randn(shape)
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module

        p = ReplaceConstantPadNdWithSlicePass()

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.slice.Tensor),
            1,
        )

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.constant_pad_nd.default),
            0,
        )

    @parameterized.expand(
        [
            [(7, 5, 6), 1.23],
            [(7, 5), 2],
        ]
    )
    @torch.no_grad()
    def test_add_replace_scalar_with_tensor_arg(self, shape: Tuple[int], other: float):
        class Add(torch.nn.Module):
            def forward(self, x):
                return torch.ops.aten.add.Scalar(x, other)

        model = Add()
        x = torch.randn(shape)
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module

        p = ReplaceScalarWithTensorArgPass()

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor),
            1,
        )

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.add.Scalar),
            0,
        )

    @parameterized.expand(
        [
            [(7, 5, 6), 1.23],
            [(7, 5), 2],
            [(10), 42949],
        ]
    )
    @torch.no_grad()
    def test_sub_replace_scalar_with_tensor_arg(self, shape: Tuple[int], other: float):
        class Sub(torch.nn.Module):
            def forward(self, x):
                return torch.ops.aten.sub.Scalar(x, other)

        model = Sub()
        x = torch.randn(shape)
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module

        p = ReplaceScalarWithTensorArgPass()

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.sub.Tensor),
            1,
        )

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.sub.Scalar),
            0,
        )

    @parameterized.expand(
        [
            [(7, 5, 6), 1.23],
            [(7, 5), 2],
            [(513), 3],
        ]
    )
    @torch.no_grad()
    def test_mul_replace_scalar_with_tensor_arg(self, shape: Tuple[int], other: float):
        class Mul(torch.nn.Module):
            def forward(self, x):
                return torch.ops.aten.mul.Scalar(x, other)

        model = Mul()
        x = torch.randn(shape)
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module

        p = ReplaceScalarWithTensorArgPass()

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor),
            1,
        )

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.mul.Scalar),
            0,
        )

    @parameterized.expand(
        [
            [(7, 5, 6), 1.23],
            [(7, 5), 2],
        ]
    )
    @torch.no_grad()
    def test_div_replace_scalar_with_tensor_arg(
        self,
        shape: Tuple[int],
        other: float,
    ):
        class Div(torch.nn.Module):
            def forward(self, x):
                return torch.ops.aten.div.Scalar(x, other)

        model = Div()
        x = torch.randn(shape)
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module

        p = ReplaceScalarWithTensorArgPass()

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.div.Tensor),
            1,
        )

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.div.Scalar),
            0,
        )

    @parameterized.expand(
        [
            [(2, 3, 5, 6)],
            [(7, 6, 5)],
            [(4, 4)],
            [(316)],
        ]
    )
    @torch.no_grad()
    def test_replace_functionally_equivalent_op_targets_relu(self, shape: Tuple[int]):
        model = torch.nn.ReLU()
        x = torch.randn(shape)
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
        p = ReplaceFunctionallyEquivalentOpTargets()

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.relu.default),
            1,
        )
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.relu_.default),
            0,
        )

    @parameterized.expand(
        [
            # split the only dimension
            [(50,), i, 0]
            for i in range(2, 7)
        ]
        + [
            # split the leading dim
            [(10, 2, 3), i, 0]
            for i in range(2, 7)
        ]
        + [
            # split the trailing dim
            [(3, 3, 6), i, 2]
            for i in range(2, 6)
        ]
        + [
            # split the dim in the middle
            [(3, 5, 14, 2, 3), i, 2]
            for i in range(2, 7)
        ]
    )
    @torch.no_grad()
    def test_replace_functionally_equivalent_op_targets_unsafe_split(
        self, shape: Tuple[int], split_size: int, dim: int
    ):
        class TensorSplitWithSizes(torch.nn.Module):
            def __init__(self, split_size: int, dim: int, op: OpOverload):
                super().__init__()
                self.split_size = split_size
                self.dim = dim
                self.op = op

            def forward(self, x: torch.Tensor):
                return self.op(x, self.split_size, self.dim)

        x = torch.randn(shape)
        model = TensorSplitWithSizes(split_size, dim, torch.unsafe_split)
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
        p = ReplaceFunctionallyEquivalentOpTargets()

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
        self.assertEqual(
            count_node(
                graph_after_passes, exir_ops.edge.aten.split_with_sizes_copy.default
            ),
            1,
        )
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.unsafe_split.Tensor),
            0,
        )

    @parameterized.expand(
        [
            [(16, 32)],
            [(1, 240)],
            [(4, 16)],
        ]
    )
    @torch.no_grad()
    def test_replace_t_copy_with_transpose(self, shape: Tuple[int]):
        class TCopy(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return exir_ops.edge.aten.t_copy(x)

        w = torch.randn(shape)
        inputs = (w,)
        p1 = ReplaceTCopyWithTransposePass()
        p2 = ReplacePermuteWithTransposePass()
        model = TCopy()
        graph_module = export_to_edge(model, inputs).exported_program().graph_module
        graph_after_passes = cast(
            PassResult, p2(cast(PassResult, p1(graph_module)).graph_module)
        ).graph_module
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int),
            1,
        )
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.t_copy),
            0,
        )

    @parameterized.expand(
        [
            [(1, 8, 33), 8, 16, 3],
            [(1, 8, 33), 8, 16, 5, 2],
            [(1, 8, 33), 8, 16, 3, 2, 4, 3, False, False, False],
            # channel last
            [(1, 33, 8), 8, 16, 3, 1, 0, 1, False, False, True],
            [(1, 33, 8), 8, 16, 5, 2, 0, 1, False, True, True],
        ]
    )
    @torch.no_grad()
    def test_replace_transposed_conv_with_linear(
        self,
        shape: Tuple[int],
        in_channels: int,
        out_channels: int,
        kernel: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        depthwise: bool = False,
        bias: bool = True,
        channel_last: bool = False,
    ):
        class TConv(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.tconv1d = torch.nn.ConvTranspose1d(
                    in_channels,
                    out_channels,
                    kernel,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=in_channels if depthwise else 1,
                    bias=bias,
                )

            def forward(self, x: torch.Tensor):
                if channel_last:
                    x = x.permute([0, 2, 1])
                x = self.tconv1d(x)
                if channel_last:
                    x = x.permute([0, 2, 1])
                return x

        x = torch.randn(shape)
        model = TConv()
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
        p1 = ReplaceAtenConvolutionWithJarvisConvolutionPass()
        p2 = ReplaceTransposedConvWithLinearPass()
        graph_after_passes = cast(
            PassResult, p2(cast(PassResult, p1(graph_module)).graph_module)
        ).graph_module
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.linear.default),
            1,
        )
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.convolution.default),
            0,
        )
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default),
            0,
        )

    @parameterized.expand(
        [
            [(1, 8, 33), 8, 16, 3, 2, 4, 3, False, False, False],
            # depthwise
            [(1, 8, 33), 8, 16, 3, 1, 0, 1, True, False, False],
            [(1, 8, 33), 8, 16, 3, 2, 4, 3, True, False, False],
            # channel last (uses a permute op before calling conv1d)
            [(1, 33, 8), 8, 16, 3, 1, 0, 1, False, False, True],
            [(1, 33, 8), 8, 16, 3, 2, 4, 3, True, False, True],
        ]
    )
    @torch.no_grad()
    def test_replace_convolution_optional_args_with_concrete_args(
        self,
        shape: Tuple[int],
        in_channels: int,
        out_channels: int,
        kernel: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        depthwise: bool = False,
        bias: bool = True,
        channel_last: bool = False,
    ):
        class Conv(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv1d = torch.nn.Conv1d(
                    in_channels,
                    out_channels,
                    kernel,
                    stride=stride,
                    padding=padding,
                    dilation=dilation,
                    groups=in_channels if depthwise else 1,
                    bias=bias,
                )

            def forward(self, x: torch.Tensor):
                if channel_last:
                    x = x.permute([0, 2, 1])
                x = self.conv1d(x)
                if channel_last:
                    x = x.permute([0, 2, 1])
                return x

        x = torch.randn(shape)
        model = Conv()

        graph_module = export_to_edge(model, (x,)).exported_program().graph_module

        p = ReplaceConvolutionOptionalArgsWithConcreteArgsPass()

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.full.default),
            1,
        )
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.convolution.default),
            1,
        )

    @parameterized.expand(
        [
            [(1, 2, 3), (1, 1)],
            [
                (20, 1, 80),
                (1, 4),
            ],
        ]
    )
    @torch.no_grad()
    def test_replace_pad_with_cat(self, shape: Tuple[int], padding: Tuple[int]):
        # F.pad is converted to aten::constant_pad_nd after functionalization & decomposition.
        class Padding(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.padding = padding

            def forward(self, x: torch.Tensor):
                return F.pad(x, self.padding)

        model = Padding()
        x = torch.randn(shape)
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module

        p = ReplacePadWithCatPass()

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.cat.default),
            1,
        )

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.pad.default),
            0,
        )

    @torch.no_grad()
    def test_replace_repeat_with_cat(self):
        class Repeat(torch.nn.Module):
            def forward(self, x):
                x1 = torch.add(x, 2.4, 3.1)
                return torch.ops.aten.repeat(x1, [1, 2])

        x = torch.ones(3, 5)
        graph_module = export_to_edge(Repeat(), (x,)).exported_program().graph_module

        p = ReplaceRepeatWithCatPass()
        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.cat.default),
            1,
        )

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.repeat.default),
            0,
        )

    @parameterized.expand(
        [
            # x, mask
            [(1,)],
            [(3, 4)],
            [(7, 8, 3)],
            [(3, 3, 2, 4)],
            [(36, 1, 2, 80), (1)],
            # tests where mask will be broadcasted
            [(36, 1, 2, 80), (1, 1, 2, 1)],
            [(36, 2, 8, 4), (36, 1, 1, 4)],
            [(36, 2, 8, 4), (2, 1, 4)],
        ]
    )
    @torch.no_grad()
    def test_replace_masked_scalar_tensor_with_full(
        self,
        shape: Tuple[int],
        mask_shape: Union[Tuple[int, ...], None] = None,
    ):
        class MaskedFill(torch.nn.Module):
            def __init__(self, value: float):
                super().__init__()
                self.value = value

            def forward(self, x: torch.Tensor, mask: torch.Tensor):
                return torch.masked_fill(x, mask, self.value)

        x = torch.randn(shape)
        mask = torch.randn(mask_shape if mask_shape else shape) > 0
        value = 0.5 * torch.mean(x).item()
        model = MaskedFill(value)
        graph_module = export_to_edge(model, (x, mask)).exported_program().graph_module

        p = ReplaceScalarTensorWithFullPass()

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.full.default),
            1,
        )

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.where.self),
            1,
        )

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.masked_fill),
            0,
        )

    @parameterized.expand(
        [
            [(1), 1.5],
            [(1), 0.0],
        ]
    )
    @torch.no_grad()
    def test_replace_scalar_tensor_with_full(self, shape: Tuple[int], value: float):
        class ScalarTensor(torch.nn.Module):
            def __init__(self, shape: Tuple[int], value: float):
                super().__init__()
                self.shape = shape
                self.value = value

            def forward(self, x: torch.Tensor):
                return torch.scalar_tensor(value)

        model = ScalarTensor(shape, value)
        x = torch.randn(shape)
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module

        p = ReplaceScalarTensorWithFullPass()

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.full.default),
            1,
        )

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.scalar_tensor.default),
            0,
        )

    @torch.no_grad()
    def test_replace_linear_with_fully_connected(self):
        shape, in_features, out_features, bias = (1, 14), 14, 128, False
        model = torch.nn.Linear(in_features, out_features, bias=bias)
        x = torch.randn(shape)

        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
        permute_to_trans_pass = ReplacePermuteWithTransposePass()
        mm_to_addmm_pass = ReplaceMMWithAddMMPass()
        add_to_linear_pass = ReplaceAddMMWithLinearPass()
        linear_to_fullyconnected_pass = ReplaceLinearWithFullyConnectedOpPass()
        graph_after_passes = linear_to_fullyconnected_pass(
            add_to_linear_pass(
                mm_to_addmm_pass(
                    permute_to_trans_pass(graph_module).graph_module
                ).graph_module
            ).graph_module
        ).graph_module
        self.assertIsNotNone(graph_after_passes)

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.full.default),
            1,
        )

        self.assertEqual(
            count_node(
                graph_after_passes, exir_ops.edge.cadence.fully_connected.default
            ),
            1,
        )

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.linear),
            0,
        )

    @parameterized.expand(
        [
            [(4, 16, 256), 256, 512, True],
            [(7, 17, 12), 12, 34, False],
        ]
    )
    @torch.no_grad()
    def test_replace_addmm_with_linear(
        self, shape: Tuple[int], in_features: int, out_features: int, bias: bool
    ):
        class AddMM(torch.nn.Module):
            def __init__(self, alpha: float = 1, beta: float = 1):
                super().__init__()
                self.alpha = alpha
                self.beta = beta

            def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
                return torch.addmm(
                    x, y, z.transpose(1, 0), alpha=self.alpha, beta=self.beta
                )

        # alpha and beta must both be 1 to enable ReplaceAddMMWithLinearPass
        # get_attr will always turn into placeholders and mutable outputs in PT2
        M, K, N, alpha, beta = 14, 48, 24, 1.0, 1.0
        x = torch.randn(N)
        y = torch.randn(M, K)
        z = torch.randn(N, K)

        # test addmm
        model = AddMM(alpha=alpha, beta=beta)
        graph_module = export_to_edge(model, (x, y, z)).exported_program().graph_module

        tp = ReplacePermuteWithTransposePass()
        ap = ReplaceAddMMWithLinearPass()
        graph_after_passes = cast(
            PassResult, ap(cast(PassResult, tp(graph_module)).graph_module)
        ).graph_module
        self.assertIsNotNone(graph_after_passes)

        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.addmm.default),
            1,
        )

        # Assert that all the aten.addmm nodes are removed.
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.linear.default),
            1,
        )
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.addmm.default),
            0,
        )

    @torch.no_grad()
    def test_replace_mm_with_addmm(self):
        # The mm ops will be converted to addmm ops by Jarvis
        class MM(torch.nn.Module):
            def __init__(self, K, N):
                super().__init__()
                self.K = K
                self.N = N

            def forward(self, y: torch.Tensor, z: torch.Tensor):
                return torch.ops.aten.mm(y, z)

        M, K, N = 14, 48, 24
        y = torch.randn(M, K)
        z = torch.randn(K, N)

        # test mm
        model = MM(K, N)
        graph_module = export_to_edge(model, (y, z)).exported_program().graph_module

        # First, replace the aten.mm with an aten.addmm op
        p = ReplaceMMWithAddMMPass()
        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
        self.assertIsNotNone(graph_after_passes)

        # Assert that all the aten.mm nodes are removed.
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.addmm.default),
            1,
        )

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.mm),
            0,
        )

    @parameterized.expand(
        [
            # shape
            [(5, 1, 6, 7)],
            [(1)],
            [(4, 3, 2)],
            # shape, dim to squeeze
            [(2, 1), 0],
            [(2, 7, 1, 3), 1],
            [(2, 1, 3), 2],
        ]
    )
    @torch.no_grad()
    def test_replace_squeeze_with_view(self, shape: Tuple[int], dim=None):
        # The squeeze ops will be converted to view ops by Jarvis
        class Squeeze(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.dim = dim

            def forward(self, x: torch.Tensor):
                if self.dim is None:
                    return torch.squeeze(x)
                return torch.squeeze(x, self.dim)

        model = Squeeze(dim)
        x = torch.randn(shape)
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module

        # First, replace the aten.squeeze_copy with an aten.view_copy op
        p = ReplaceSqueezeAndUnsqueezeWithViewPass()

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
        self.assertIsNotNone(graph_after_passes)

        # Assert that all the aten.squeeze_copy nodes are removed.
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default),
            1,
        )
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.aten.squeeze_copy),
            0,
        )

    @parameterized.expand(
        [
            # shape, dim to unsqueeze
            [(5, 6, 7), 0],
            [(5, 6, 7), -1],
            [(5, 6, 7), 3],
            [(5, 6, 7), 2],
        ]
    )
    @torch.no_grad()
    def test_replace_unsqueeze_with_view(self, shape: Tuple[int], dim: int):
        class Unsqueeze(torch.nn.Module):
            def __init__(self, dim):
                super().__init__()
                self.dim = dim

            def forward(self, x: torch.Tensor):
                return torch.unsqueeze(x, self.dim)

        # Test that the pass works for all dims.
        model = Unsqueeze(dim)
        x = torch.randn(5, 6, 7)
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module

        # First, replace the aten.unsqueeze_copy with an aten.view_copy op
        p = ReplaceSqueezeAndUnsqueezeWithViewPass()

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
        self.assertIsNotNone(graph_after_passes)

        # Assert that all the aten.unsqueeze_copy nodes are removed.
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default),
            1,
        )
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.aten.unsqueeze_copy),
            0,
        )

    @torch.no_grad()
    def test_replace_single_element_tensor_arguments_from_full_op_with_scalar(
        self,
        in_features: int = 16,
        out_features: int = 16,
    ):
        # Tensors - these will be inputs to graph.
        x = torch.randn([1, in_features])

        inputs = (x,)
        model = torch.nn.Linear(in_features=in_features, out_features=out_features)
        quantized_model = quantize_pt2(model, inputs)

        exported_program = export_to_edge(quantized_model, inputs).exported_program()

        # By default, the quantized linear op should have constant scalar attributes.
        self.assertTargetCountsEqual(
            exported_program.graph_module,
            [
                # One quantized linear op.
                (exir_ops.edge.cadence.quantized_linear.default, 1),
                # No per tensor quantized linear ops.
                (exir_ops.edge.cadence.quantized_linear.per_tensor, 0),
                # Three aten.full ops.
                (exir_ops.edge.aten.full.default, 3),
            ],
        )

        # Apply replacement pass.
        p = ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass()
        graph_after_passes = p(exported_program.graph_module)
        self.assertIsNotNone(graph_after_passes)
        gm = dead_code_elimination_pass(graph_after_passes.graph_module).graph_module

        # After the pass, the single-element tensor args should be replaced with scalars.
        self.assertTargetCountsEqual(
            gm,
            [
                # No default quantized linear op.
                (exir_ops.edge.cadence.quantized_linear.default, 0),
                # The default quantized linear op will be replaced with quantized_linear.per_tensor.
                (exir_ops.edge.cadence.quantized_linear.per_tensor, 1),
                # No aten.full ops.
                (exir_ops.edge.aten.full.default, 0),
            ],
        )

    @torch.no_grad()
    def test_replace_single_element_tensor_arguments_from_full_op_with_scalar_tuple_args(
        self,
        in_features: int = 16,
        out_features: int = 16,
    ):
        # Tensors - these will be inputs to graph.
        x = torch.randn([1, in_features])

        inputs = (x,)
        model = torch.nn.Linear(in_features=in_features, out_features=out_features)
        quantized_model = quantize_pt2(model, inputs)

        exported_program = export_to_edge(quantized_model, inputs).exported_program()

        # By default, the quantized linear op should have constant scalar attributes.
        self.assertTargetCountsEqual(
            exported_program.graph_module,
            [
                # One quantized linear op.
                (exir_ops.edge.cadence.quantized_linear.default, 1),
                # No per tensor quantized linear ops.
                (exir_ops.edge.cadence.quantized_linear.per_tensor, 0),
                # Three aten.full ops.
                (exir_ops.edge.aten.full.default, 3),
            ],
        )

        for node in exported_program.graph_module.graph.nodes:
            # Replace the `shape` argument for aten.full op with a tuple.
            if node.target == exir_ops.edge.aten.full.default:
                node.args = (tuple(node.args[0]), node.args[1])

        # Apply replacement pass.
        p = ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass()
        graph_after_passes = p(exported_program.graph_module)
        self.assertIsNotNone(graph_after_passes)
        gm = dead_code_elimination_pass(graph_after_passes.graph_module).graph_module

        # After the pass, the single-element tensor args should be replaced with scalars.
        self.assertTargetCountsEqual(
            gm,
            [
                # No default quantized linear op.
                (exir_ops.edge.cadence.quantized_linear.default, 0),
                # The default quantized linear op will be replaced with quantized_linear.per_tensor.
                (exir_ops.edge.cadence.quantized_linear.per_tensor, 1),
                # No aten.full ops.
                (exir_ops.edge.aten.full.default, 0),
            ],
        )

    @torch.no_grad()
    def test_replace_conv1d_with_linear(self):
        class Conv(torch.nn.Module):
            def __init__(self, in_features: int, out_features: int, kernel_size: int):
                super().__init__()
                self.conv1d = torch.nn.Conv1d(in_features, out_features, kernel_size)

            def forward(self, x):
                return self.conv1d(x)

        model_conv1d = Conv(96, 192, 7)
        x = torch.randn(1, 96, 7)
        graph_module = (
            export_to_edge(model_conv1d, (x,)).exported_program().graph_module
        )

        # First, replace the aten convolution with a cadence.convolution op
        p1 = ReplaceAtenConvolutionWithJarvisConvolutionPass()
        temp_graph = p1(graph_module).graph_module
        self.assertIsNotNone(temp_graph)

        p2 = ReplaceTrivialConvWithLinear()
        graph_after_passes = p2(temp_graph).graph_module

        # Assert that conv1d is trivially converted to linear
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default), 0
        )
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.cadence.im2row.default), 0
        )
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.linear.default)
            + count_node(
                graph_after_passes, exir_ops.edge.cadence.fully_connected.default
            ),
            1,
        )

    @torch.no_grad()
    def test_replace_conv2d_with_linear(self):
        class Conv(torch.nn.Module):
            def __init__(self, in_features: int, out_features: int, kernel_size: int):
                super().__init__()
                self.conv2d = torch.nn.Conv2d(in_features, out_features, kernel_size)

            def forward(self, x):
                return self.conv2d(x)

        model_conv2d = Conv(96, 192, 7)
        x = torch.randn(1, 96, 7, 7)
        graph_module = (
            export_to_edge(model_conv2d, (x,)).exported_program().graph_module
        )

        # First, replace the aten convolution with a cadence.convolution op
        p1 = ReplaceAtenConvolutionWithJarvisConvolutionPass()
        temp_graph = p1(graph_module).graph_module
        self.assertIsNotNone(temp_graph)

        p2 = ReplaceTrivialConvWithLinear()
        graph_after_passes = p2(temp_graph).graph_module

        # Assert that conv2d is trivially converted to linear
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default), 0
        )
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.cadence.im2row.default), 0
        )
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.linear.default)
            + count_node(
                graph_after_passes, exir_ops.edge.cadence.fully_connected.default
            ),
            1,
        )

    @torch.no_grad()
    def test_replace_conv2d_with_im2row_and_linear(self):
        class Conv(torch.nn.Module):
            def __init__(self, in_features: int, out_features: int, kernel_size: int):
                super().__init__()
                self.conv2d = torch.nn.Conv2d(in_features, out_features, kernel_size)

            def forward(self, x):
                return self.conv2d(x)

        model_conv2d = Conv(96, 192, 7)
        x = torch.randn(1, 96, 47, 37)
        graph_module = (
            compiler.export_to_cadence(model_conv2d, (x,))
            .exported_program()
            .graph_module
        )

        p = ReplaceConvWithIm2RowAndLinear()
        graph_after_passes = cast(PassResult, p(graph_module)).graph_module

        # Assert that the convolution is converted to im2row + linear
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default), 0
        )
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.cadence.im2row.default), 1
        )
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.linear.default), 1
        )

    @parameterized.expand(
        [
            [(3, 1, 5), 1, 0],
            [(3, 4, 1), 2, -1],
        ]
    )
    @torch.no_grad()
    def test_replace_select_with_view(self, shape: Tuple[int], dim: int, index: int):
        class Select(torch.nn.Module):
            def forward(self, x):
                return x.select(dim, index)

        x = torch.randn(shape)
        graph_module = export_to_edge(Select(), (x,)).exported_program().graph_module

        p = ReplaceSelectWithViewOpPass()

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module

        # Assert that select op was replaced with view op
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0
        )
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 1
        )

    @parameterized.expand(
        [
            [(2, 1, 3, 1), 1, 3, torch.float32],
            [(2, 1, 5), 1, 0, torch.int64],
            [(3, 1, 5), 0, 1, torch.int64],
        ]
    )
    @torch.no_grad()
    def test_replace_nop_transpose_with_view(
        self,
        shape: Tuple[int],
        dim0: int,
        dim1: int,
        dtype: torch.dtype = torch.float32,
    ):
        class Transpose(torch.nn.Module):
            def forward(self, x):
                return x.transpose(dim0, dim1)

        _max_value = 127
        x = (torch.rand(shape) * _max_value).to(dtype=dtype)
        graph_module = export_to_edge(Transpose(), (x,)).exported_program().graph_module

        p = ReplaceNopTransposeOrPermuteWithViewPass()

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module

        # Assert that transpose op was removed, and a view op was placed instead
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int), 0
        )
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 1
        )

    @parameterized.expand(
        [
            # permutations that can be replaced by view
            [(3, 1, 3, 1, 4), (0, 2, 4, 1, 3), torch.float32],
            [(1, 3, 4), (1, 2, 0), torch.float32],
        ]
    )
    @torch.no_grad()
    def test_replace_nop_permute_with_view(self, input_shape, dims, dtype):
        class Permute(torch.nn.Module):
            def forward(self, x):
                return torch.permute(x, dims)

        x = torch.randn(input_shape).to(dtype=dtype)
        graph_module = export_to_edge(Permute(), (x,)).exported_program().graph_module

        p = ReplaceNopTransposeOrPermuteWithViewPass()
        graph_after_passes = cast(PassResult, p(graph_module)).graph_module

        # Assert that permute op was removed, and a view op was placed instead
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.permute_copy.default), 0
        )
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 1
        )

    @parameterized.expand(
        [
            # permutations replaced by transpose
            [(3, 4), [1, 0], torch.float32],
            [(3, 4, 6), (0, 2, 1), torch.float32],
        ]
    )
    @torch.no_grad()
    def test_replace_permute_with_transpose(self, input_shape, dims, dtype):
        class Permute(torch.nn.Module):
            def forward(self, x):
                return torch.permute(x, dims)

        x = torch.randn(input_shape).to(dtype=dtype)
        graph_module = export_to_edge(Permute(), (x,)).exported_program().graph_module

        p = ReplacePermuteWithTransposePass()
        graph_after_passes = cast(PassResult, p(graph_module)).graph_module

        # Assert that permute op was replaced by a transpose op
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.permute_copy.default), 0
        )
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int), 1
        )

    @parameterized.expand(
        [
            # nop (identity) permutation
            [(3, 4), [0, 1], torch.float32],
        ]
    )
    @torch.no_grad()
    def test_replace_permute_with_transpose_nop(self, input_shape, dims, dtype):
        class Permute(torch.nn.Module):
            def forward(self, x):
                return torch.permute(x, dims)

        x = torch.randn(input_shape).to(dtype=dtype)
        graph_module = export_to_edge(Permute(), (x,)).exported_program().graph_module

        p = ReplacePermuteWithTransposePass()
        graph_after_passes = cast(PassResult, p(graph_module)).graph_module

        # Assert that the nop permute op was removed without adding a transpose op
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.permute_copy.default), 0
        )
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int), 0
        )

    def test_replace_aten_linalg_vector_norm_with_cadence_linalg_vector_norm(self):
        class LinalgVectorNorm(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return torch.linalg.vector_norm(x)

        x = torch.randn(32)

        graph_module = (
            export_to_edge(LinalgVectorNorm(), (x,)).exported_program().graph_module
        )

        p = ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass()
        graph_after_passes = cast(PassResult, p(graph_module)).graph_module

        # Assert that aten.linalg_vector_norm op was replaced by a
        # cadence.linalg_vector_norm op
        self.assertEqual(
            count_node(
                graph_after_passes,
                exir_ops.edge.aten.linalg_vector_norm.default,
            ),
            0,
        )
        self.assertEqual(
            count_node(
                graph_after_passes, exir_ops.edge.cadence.linalg_vector_norm.default
            ),
            1,
        )


class TestReplaceIm2rowWithViewPass(unittest.TestCase):
    def test_no_replacement_for_conv(self):
        # Create a graph with a single im2row node.
        x = torch.randn(1, 3, 224, 224)
        pad_value = torch.randn(1)
        channels_last = False
        gm = single_op_builder(
            placeholders=(x, pad_value),
            op=exir_ops.edge.cadence.im2row.default,
            args=(x, (2, 2), (1, 1), (0, 0), (1, 1), pad_value, channels_last),
        )
        # Check if graph module is valid by running exportpass on it.
        gm = ExportPass().call(gm).graph_module
        self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1)
        self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 0)

        # Apply replacement pass.
        p = ReplaceIm2RowWithViewPass()
        gm_after_replacement = p.call(gm).graph_module
        # Check that no replacement was made.
        self.assertEqual(
            count_node(gm_after_replacement, exir_ops.edge.cadence.im2row.default), 1
        )
        self.assertEqual(
            count_node(gm_after_replacement, exir_ops.edge.aten.view_copy.default), 0
        )

    def test_no_replace_for_dilation(self):
        # Create a graph with a single im2row node.
        x = torch.randn(1, 3, 5, 7)
        pad_value = torch.randn(1)
        channels_last = False
        gm = single_op_builder(
            placeholders=(x, pad_value),
            op=exir_ops.edge.cadence.im2row.default,
            args=(x, (3, 4), (2, 2), (0, 0), (1, 1), pad_value, channels_last),
        )
        # Check if graph module is valid by running exportpass on it.
        gm = ExportPass().call(gm).graph_module
        self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1)
        self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 0)

        # Apply replacement pass.
        p = ReplaceIm2RowWithViewPass()
        gm_after_replacement = p.call(gm).graph_module
        self.assertEqual(
            count_node(gm_after_replacement, exir_ops.edge.cadence.im2row.default), 1
        )
        self.assertEqual(
            count_node(gm_after_replacement, exir_ops.edge.aten.view_copy.default), 0
        )

    def test_replace_linear_like_conv(self):
        # Create a graph with a single im2row node.
        in_h, in_w = 13, 15
        x = torch.randn(1, 3, in_h, in_w)
        pad_value = torch.randn(1)
        channels_last = False
        gm = single_op_builder(
            placeholders=(x, pad_value),
            op=exir_ops.edge.cadence.im2row.default,
            args=(x, (in_h, in_w), (1, 1), (0, 0), (1, 1), pad_value, channels_last),
        )
        # Check if graph module is valid by running exportpass on it.
        gm = ExportPass().call(gm).graph_module
        self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1)
        self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 0)

        # Apply replacement pass.
        p = ReplaceIm2RowWithViewPass()
        gm_after_replacement = p.call(gm).graph_module
        # In this test, the kernel width/height is the same as the input width/height.
        self.assertEqual(
            count_node(gm_after_replacement, exir_ops.edge.cadence.im2row.default), 0
        )
        self.assertEqual(
            count_node(gm_after_replacement, exir_ops.edge.aten.view_copy.default), 1
        )


class TestForceChannelLastForConvPass(unittest.TestCase):
    def create_conv1d_graphmodule(
        self, channels_last: Optional[bool] = None
    ) -> torch.fx.GraphModule:
        """Helper to create a convolution node.

        convolution(
            Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding,
            int[] dilation, int groups, bool channel_last=False) -> (Tensor Y)
        """
        if channels_last:
            x = torch.randn(1, 224, 3)
            w = torch.randn(16, 16, 3)
        else:
            x = torch.randn(1, 3, 224)
            w = torch.randn(16, 3, 16)
        b = torch.randn(16)
        args = (x, w, b, (2, 2), (1, 1), (0, 0), 1)
        if channels_last is not None:
            args = args + (channels_last,)
        return single_op_builder(
            placeholders=(x, w, b),
            op=exir_ops.edge.cadence.convolution.default,
            args=args,
        )

    def test_conv1d_default_channel_last(self):
        # Create a graph with a single convolution node.
        # Check if graph module is valid by running exportpass on it.
        gm = self.create_conv1d_graphmodule()
        gm = ExportPass().call(gm).graph_module
        self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1)
        self.assertEqual(count_node(gm, exir_ops.edge.aten.transpose_copy.int), 0)

        # Apply replacement pass.
        p = ForceChannelLastForConvPass()
        gm_after_replacement = p.call(gm).graph_module
        # Check that the convolution node is preserved.
        self.assertEqual(
            count_node(gm_after_replacement, exir_ops.edge.cadence.convolution.default),
            1,
        )
        self.assertEqual(
            count_node(gm_after_replacement, exir_ops.edge.aten.transpose_copy.int),
            # Three transposes are added: one for the input, one for the weight,
            # and one for the output.
            3,
        )
        for node in gm_after_replacement.graph.nodes:
            if node.target != exir_ops.edge.cadence.convolution.default:
                continue
            # Check that the channel_last argument is set to True.
            self.assertEqual(len(node.args), 8, f"{node=}")
            self.assertTrue(node.args[7])

    def test_conv1d_no_transpose_if_already_channel_last(self):
        gm = self.create_conv1d_graphmodule(channels_last=True)
        gm = ExportPass().call(gm).graph_module
        self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1)

        # Apply replacement pass.
        p = ForceChannelLastForConvPass()
        gm_after_replacement = p.call(gm).graph_module
        # Check that no replacement was made.
        self.assertEqual(
            count_node(gm_after_replacement, exir_ops.edge.cadence.convolution.default),
            1,
        )
        self.assertEqual(
            count_node(gm_after_replacement, exir_ops.edge.aten.transpose_copy.int),
            0,
        )
        for node in gm_after_replacement.graph.nodes:
            if node.target != exir_ops.edge.cadence.convolution.default:
                continue
            # Check that the channel_last argument is set to True.
            self.assertEqual(len(node.args), 8, f"{node=}")
            self.assertTrue(node.args[7])

    def create_convolution_graph_module(
        self, channels_last: Optional[bool] = None
    ) -> torch.fx.GraphModule:
        """Helper to create a convolution node.

        convolution(
            Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding,
            int[] dilation, int groups, bool channel_last=False) -> (Tensor Y)
        """
        if channels_last:
            x = torch.randn(1, 224, 224, 3)
            w = torch.randn(16, 16, 16, 3)
        else:
            x = torch.randn(1, 3, 224, 224)
            w = torch.randn(16, 3, 16, 16)
        b = torch.randn(16)
        args = (x, w, b, (2, 2), (1, 1), (0, 0), 1)
        if channels_last is not None:
            args = args + (channels_last,)
        return single_op_builder(
            placeholders=(x, w, b),
            op=exir_ops.edge.cadence.convolution.default,
            args=args,
        )

    def test_convolution_default_channel_last(self):
        # Create a graph with a single convolution node.
        # Check if graph module is valid by running exportpass on it.
        gm = self.create_convolution_graph_module()
        gm = ExportPass().call(gm).graph_module
        self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1)
        self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0)

        # Apply replacement pass.
        p = ForceChannelLastForConvPass()
        gm_after_replacement = p.call(gm).graph_module
        # Check that the convolution node is preserved.
        self.assertEqual(
            count_node(gm_after_replacement, exir_ops.edge.cadence.convolution.default),
            1,
        )
        self.assertEqual(
            count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default),
            # Three permutes are added, two for the input/weights and one for the output.
            3,
        )
        for node in gm_after_replacement.graph.nodes:
            if node.target != exir_ops.edge.cadence.convolution.default:
                continue
            # Check that the channel_last argument is set to True.
            self.assertEqual(len(node.args), 8, f"{node=}")
            self.assertTrue(node.args[7])

    def test_no_transpose_if_already_channel_last(self):
        gm = self.create_convolution_graph_module(channels_last=True)
        gm = ExportPass().call(gm).graph_module
        self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1)

        # Apply replacement pass.
        p = ForceChannelLastForConvPass()
        gm_after_replacement = p.call(gm).graph_module
        # Check that no replacement was made.
        self.assertEqual(
            count_node(gm_after_replacement, exir_ops.edge.cadence.convolution.default),
            1,
        )
        self.assertEqual(
            count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default),
            0,
        )
        for node in gm_after_replacement.graph.nodes:
            if node.target != exir_ops.edge.cadence.convolution.default:
                continue
            # Check that the channel_last argument is set to True.
            self.assertEqual(len(node.args), 8, f"{node=}")
            self.assertTrue(node.args[7])

    def create_quantized_convolution_graph_module(
        self, channels_last: Optional[bool] = None
    ) -> torch.fx.GraphModule:
        """Helper to create a quantized conv node.

        quantized_conv(
            Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding,
            int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point,
            Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier,
            Tensor out_shift, bool channel_last=False) -> (Tensor Z)
        """
        if channels_last:
            x = torch.randn(1, 224, 56, 3)
            w = torch.randn(16, 16, 16, 3)
        else:
            x = torch.randn(1, 3, 224, 56)
            w = torch.randn(16, 3, 16, 16)
        b = torch.randn(16)
        stride = (2, 2)
        padding = (0, 0)
        dilation = (1, 1)
        groups = 1
        input_zero_point = 0
        w_zero_point = torch.randn(1)
        b_scale = torch.randn(1)
        out_scale = 1
        out_zero_point = 0
        out_multiplier = torch.randn(1)
        out_shift = torch.randn(1)
        args = (
            x,
            w,
            b,
            stride,
            padding,
            dilation,
            groups,
            input_zero_point,
            w_zero_point,
            b_scale,
            out_scale,
            out_zero_point,
            out_multiplier,
            out_shift,
        )
        if channels_last is not None:
            args = args + (channels_last,)
        return single_op_builder(
            placeholders=(x, w, b, w_zero_point, b_scale, out_multiplier, out_shift),
            op=exir_ops.edge.cadence.quantized_conv.default,
            args=args,
        )

    def test_quantized_convolution_default_channel_last(self):
        # Create a graph with a single quantized convolution node.
        gm = self.create_quantized_convolution_graph_module()
        self.assertEqual(
            count_node(gm, exir_ops.edge.cadence.quantized_conv.default), 1
        )
        self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0)

        # Apply replacement pass.
        p = ForceChannelLastForConvPass()
        gm_after_replacement = p.call(gm).graph_module
        # Check that the quantized_conv node is preserved.
        self.assertEqual(
            count_node(
                gm_after_replacement, exir_ops.edge.cadence.quantized_conv.default
            ),
            1,
        )
        # Three permutes are added, two for the input/weights and one for the output.
        self.assertEqual(
            count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default),
            3,
        )
        for node in gm_after_replacement.graph.nodes:
            if node.target != exir_ops.edge.cadence.quantized_conv.default:
                continue
            # Check that the channel_last argument is set to True.
            self.assertEqual(len(node.args), 15, f"{node=}")
            self.assertTrue(node.args[14])

    def test_no_transpose_if_already_quantized_conv_channel_last(self):
        # Create a graph with a single quantized_conv node.
        gm = self.create_quantized_convolution_graph_module(channels_last=True)
        # Check if graph module is valid by running exportpass on it.
        gm = ExportPass().call(gm).graph_module
        self.assertEqual(
            count_node(gm, exir_ops.edge.cadence.quantized_conv.default), 1
        )

        # Apply replacement pass.
        p = ForceChannelLastForConvPass()
        gm_after_replacement = p.call(gm).graph_module
        # Check that no replacement was made.
        self.assertEqual(
            count_node(
                gm_after_replacement, exir_ops.edge.cadence.quantized_conv.default
            ),
            1,
        )
        self.assertEqual(
            count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default), 0
        )
        for node in gm_after_replacement.graph.nodes:
            if node.target != exir_ops.edge.cadence.quantized_conv.default:
                continue
            # Check that the channel_last argument is set to True.
            self.assertEqual(len(node.args), 15, f"{node=}")
            self.assertTrue(node.args[14])


class TestMakeSliceAndCatDimOutermostPass(unittest.TestCase):
    def create_slice_graph(
        self,
        input_shape: Sequence[int],
        slice_dim: int,
        slice_begin: Optional[int] = None,
        slice_end: Optional[int] = None,
    ) -> torch.fx.GraphModule:
        x = torch.randn(*input_shape)
        return single_op_builder(
            placeholders=(x,),
            op=exir_ops.edge.aten.slice_copy.Tensor,
            args=(x, slice_dim, slice_begin, slice_end),
        )

    def test_slice_no_transpose_if_already_outermost(self):
        # Create a graph with a single slice node.
        gm = self.create_slice_graph((3, 224, 224), 0, 1, 2)
        # Check if graph module is valid by running exportpass on it.
        gm = ExportPass().call(gm).graph_module
        self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_copy.Tensor), 1)

        # Apply replacement pass.
        gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module

        # Assert that no transpose ops were added.
        self.assertEqual(
            count_node(gm_after_pass, exir_ops.edge.aten.transpose_copy.int),
            0,
        )

    def test_slice_no_transpose_if_outermost_dimensions_are_one(self):
        # Create a graph with a single slice node on the second outermost dimension.
        gm = self.create_slice_graph((1, 3, 4, 6), 1, 1, 2)
        # Check if graph module is valid by running exportpass on it.
        gm = ExportPass().call(gm).graph_module
        self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_copy.Tensor), 1)

        # Apply replacement pass.
        gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module

        # Assert that no transpose ops were added. The slice is on the second
        # outermost dimension, but the outermost dimension is already 1.
        self.assertEqual(
            count_node(gm_after_pass, exir_ops.edge.aten.transpose_copy.int),
            0,
        )

    def test_slice_insert_transpose(self):
        # Create a graph with a single slice node.
        gm = self.create_slice_graph((1, 3, 4, 6), 2, 1, 2)
        # Check if graph module is valid by running exportpass on it.
        gm = ExportPass().call(gm).graph_module
        self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_copy.Tensor), 1)

        # Apply replacement pass.
        gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module

        # Assert that there are two transpose ops added.
        self.assertEqual(
            count_node(gm_after_pass, exir_ops.edge.aten.transpose_copy.int),
            2,
        )

    def create_cat_graph(
        self,
        input_shapes: Sequence[Sequence[int]],
        cat_dim: int = 0,
    ) -> torch.fx.GraphModule:
        input_tensors = tuple(torch.randn(s) for s in input_shapes)
        return single_op_builder(
            placeholders=input_tensors,
            op=exir_ops.edge.aten.cat.default,
            args=(input_tensors, cat_dim),
        )

    def test_cat_no_transpose_if_already_outermost(self):
        # Create a graph with a single cat node on the outermost dimension.
        gm = self.create_cat_graph(input_shapes=((1, 3, 5), (2, 3, 5)), cat_dim=0)
        # Check if graph module is valid by running exportpass on it.
        gm = ExportPass().call(gm).graph_module
        self.assertEqual(count_node(gm, exir_ops.edge.aten.cat.default), 1)

        # Apply replacement pass.
        gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module

        # Assert that no transpose ops were added. The cat is already on the
        # outermost dimension.
        self.assertEqual(
            count_node(gm_after_pass, exir_ops.edge.aten.transpose_copy.int),
            0,
        )

    def test_cat_no_transpose_if_outermost_dimensions_are_one(self):
        # Create a graph with a single cat node on the second outermost dimension.
        gm = self.create_cat_graph(input_shapes=((1, 1, 3, 5), (1, 2, 3, 5)), cat_dim=1)
        # Check if graph module is valid by running exportpass on it.
        gm = ExportPass().call(gm).graph_module
        self.assertEqual(count_node(gm, exir_ops.edge.aten.cat.default), 1)

        # Apply replacement pass.
        gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module

        # Assert that no transpose ops were added. The cat is on the second
        # outermost dimension, but the outermost dimension is already 1.
        self.assertEqual(
            count_node(gm_after_pass, exir_ops.edge.aten.transpose_copy.int),
            0,
        )

    def test_cat_insert_transpose(self):
        # Create a graph with a single cat node on the innermost dimension.
        gm = self.create_cat_graph(
            input_shapes=((1, 1, 3, 5), (1, 1, 3, 3)), cat_dim=-1
        )
        # Check if graph module is valid by running exportpass on it.
        gm = ExportPass().call(gm).graph_module
        self.assertEqual(count_node(gm, exir_ops.edge.aten.cat.default), 1)

        # Apply replacement pass.
        gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module

        # Assert that transpose ops were added to make the cat dimension outermost.
        self.assertEqual(
            count_node(gm_after_pass, exir_ops.edge.aten.transpose_copy.int),
            3,
        )
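

# Convenience entry point: allows running this test module directly with the
# standard unittest runner (the project's own test runner can also be used).
if __name__ == "__main__":
    unittest.main()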