# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.


import unittest
from typing import cast, Tuple

import executorch.backends.cadence.aot.ops_registrations  # noqa
import torch
import torch.nn as nn
import torch.nn.functional as F
from executorch.backends.cadence.aot import compiler
from executorch.backends.cadence.aot.compiler import export_to_edge

from executorch.backends.cadence.aot.pass_utils import count_node
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
from executorch.backends.cadence.aot.remove_ops import (
    RemoveAliasCopyOpPass,
    RemoveCloneOpPass,
    RemoveContiguousOpPass,
    RemoveDetachCopyPass,
    RemoveNopAddOpPass,
    RemoveNopExpandOpPass,
    RemoveNopLinalgVectorNormOpPass,
    RemoveNopMulOpPass,
    RemoveNopSelectOpPass,
    RemoveNopSliceOrViewOpPass,
    RemovePermutesAroundElementwiseOps,
    RemoveToOpsPass,
    RemoveZeroSizedCatArgsPass,
    RemoveZeroSizedConstantPadNd,
)
from executorch.exir.dialects._ops import ops as exir_ops
from parameterized.parameterized import parameterized
from pyre_extensions import none_throws

from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from torch.export import export_for_training
from torch.fx.passes.infra.pass_base import PassResult


class TestRemoveOpsPasses(unittest.TestCase):
    """Tests for the op-removal passes imported from ``remove_ops``.

    Each test builds a small module, lowers it with ``export_to_edge`` (or
    ``compiler.export_to_cadence``), optionally runs a single pass on the
    resulting graph module, and then counts the remaining nodes of the
    targeted op with ``count_node``.
    """

    def _first_node(self, graph_module: torch.fx.GraphModule, target) -> torch.fx.Node:
        """Return the first node in ``graph_module`` whose target is ``target``.

        Fails the running test with a readable message (instead of a bare
        ``IndexError``) when no such node exists.
        """
        matches = [n for n in graph_module.graph.nodes if n.target == target]
        self.assertTrue(matches, f"no node with target {target} found in graph")
        return matches[0]

    @parameterized.expand(
        [
            [(1, 2, 3)],
        ]
    )
    @torch.no_grad()
    def test_remove_to_ops(self, shape: Tuple[int, ...]):
        """RemoveToOpsPass leaves no ``to.dtype`` / ``to.dtype_layout`` nodes."""

        class M(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return exir_ops.edge.aten.to(x, dtype=torch.float32)

        model = M()
        x = torch.randn(shape)
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
        p = RemoveToOpsPass()

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.to.dtype),
            0,
        )

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.to.dtype_layout),
            0,
        )

    @parameterized.expand(
        [
            [(7, 6, 5)],
            [(7, 6)],
            [(7,)],
        ]
    )
    @torch.no_grad()
    def test_remove_nop_add_op_pass(self, shape: Tuple[int, ...]):
        """RemoveNopAddOpPass removes an add whose other operand is all zeros.

        The zero-filled tensor is tried both as the first operand (``FullX``)
        and as the second operand (``FullY``).
        """

        class FullX(torch.nn.Module):
            def forward(self, t: torch.Tensor):
                return torch.add(torch.full(shape, 0), t)

        class FullY(torch.nn.Module):
            def forward(self, t: torch.Tensor):
                return torch.add(t, torch.full(shape, 0))

        t = torch.full(shape, 3)
        # The same pass instance is reused for both operand orders.
        p = RemoveNopAddOpPass()

        for model in (FullX(), FullY()):
            with self.subTest(model=type(model).__name__):
                graph_module = (
                    export_to_edge(model, (t,)).exported_program().graph_module
                )
                graph_after_passes = cast(PassResult, p(graph_module)).graph_module
                self.assertEqual(
                    count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor),
                    0,
                )

    @parameterized.expand(
        [
            [(7, 6, 5)],
            [(7, 6)],
            [(7,)],
        ]
    )
    @torch.no_grad()
    def test_remove_nop_mul_op_pass(self, shape: Tuple[int, ...]):
        """RemoveNopMulOpPass removes a mul where one operand is all zeros.

        The zero-filled tensor is tried both as the first operand (``FullX``)
        and as the second operand (``FullY``).
        """

        class FullX(torch.nn.Module):
            def forward(self, t: torch.Tensor):
                return torch.mul(torch.full(shape, 0), t)

        class FullY(torch.nn.Module):
            def forward(self, t: torch.Tensor):
                return torch.mul(t, torch.full(shape, 0))

        t = torch.full(shape, 3)
        # The same pass instance is reused for both operand orders.
        p = RemoveNopMulOpPass()

        for model in (FullX(), FullY()):
            with self.subTest(model=type(model).__name__):
                graph_module = (
                    export_to_edge(model, (t,)).exported_program().graph_module
                )
                graph_after_passes = cast(PassResult, p(graph_module)).graph_module
                self.assertEqual(
                    count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor),
                    0,
                )

    @parameterized.expand(
        [
            [(1, 2, 3)],
        ]
    )
    @torch.no_grad()
    def test_remove_alias_copy(self, shape: Tuple[int, ...]):
        """RemoveAliasCopyOpPass leaves no ``alias_copy.default`` nodes."""

        class M(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return exir_ops.edge.aten.alias_copy(x)

        model = M()
        x = torch.randn(shape)
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module

        p = RemoveAliasCopyOpPass()

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.alias_copy.default),
            0,
        )

    @parameterized.expand(
        [
            [(1, 2, 3)],
        ]
    )
    @torch.no_grad()
    def test_remove_detach_copy(self, shape: Tuple[int, ...]):
        """RemoveDetachCopyPass leaves no ``detach_copy.default`` nodes."""

        # aten::detach is converted to aten::alias_copy after functionalization & decomposition.
        class M(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return exir_ops.edge.aten.detach_copy(x)

        model = M()
        x = torch.randn(shape)
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module

        p = RemoveDetachCopyPass()

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.detach_copy.default),
            0,
        )

    @parameterized.expand(
        [
            [(1, 2, 3), (0, 0)],
        ]
    )
    @torch.no_grad()
    def test_remove_zero_sized_constant_pad_nd(
        self, shape: Tuple[int, ...], padding: Tuple[int, ...]
    ):
        """RemoveZeroSizedConstantPadNd removes a pad whose padding is all zeros."""

        # F.pad is converted to aten::constant_pad_nd after functionalization & decomposition.
        class Padding(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.padding = padding

            def forward(self, x: torch.Tensor):
                return F.pad(x, self.padding)

        model = Padding()
        x = torch.randn(shape)
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module

        p = RemoveZeroSizedConstantPadNd()

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.constant_pad_nd.default),
            0,
        )

    def test_remove_expand(self):
        """RemoveNopExpandOpPass removes an expand to the input's own shape."""

        class Expand(torch.nn.Module):
            def forward(self, x):
                return torch.ops.aten.expand_copy(x, [2, 3, 5])

        x = torch.ones(2, 3, 5)
        p = RemoveNopExpandOpPass()
        graph_module = export_to_edge(Expand(), (x,)).exported_program().graph_module
        graph_module = p(graph_module).graph_module
        # Assert that expand op is optimized away, since it is a nop
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.expand_copy.default), 0
        )

    def test_remove_zero_arg_cat(self):
        """``export_to_cadence`` leaves no cat when both inputs are zero-sized."""

        class Cat(torch.nn.Module):
            def forward(self, x, y):
                return torch.ops.aten.cat((x, y), 0)

        x = torch.ones(1, 0, 3, 5)
        y = torch.ones(2, 0, 3, 5)
        graph_module = (
            compiler.export_to_cadence(Cat(), (x, y)).exported_program().graph_module
        )
        # Assert that cat op is optimized away, since it concatenates
        # two zero-sized tensors
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0)

    def test_remove_single_arg_cat(self):
        """RemoveZeroSizedCatArgsPass drops a cat with one non-empty input."""

        class Cat(torch.nn.Module):
            def forward(self, x, y):
                z = torch.ones(0, 5)
                # z is an empty tensor, and concatenation of x with z will
                # be x. So we can safely eliminate the following cat op.
                x1 = torch.ops.aten.cat((x, z))
                x2 = torch.add(x1, 2.4, 3.1)
                y1 = torch.add(y, 1, 2)
                return torch.add(x2, y1)

        x = torch.ones(3, 5)
        y = torch.ones(3, 5)
        graph_module = export_to_edge(Cat(), (x, y)).exported_program().graph_module
        new_graph_module = RemoveZeroSizedCatArgsPass()(graph_module).graph_module
        new_graph_module.graph.eliminate_dead_code()
        # Assert that x1 is optimized away. The graph comes from export_to_edge,
        # so the op to look for is the edge-dialect cat (as the other tests in
        # this file do); counting the `.out` variant here would be trivially 0.
        self.assertEqual(
            count_node(new_graph_module, exir_ops.edge.aten.cat.default), 0
        )

    def test_remove_zero_sized_cat(self):
        """RemoveZeroSizedCatArgsPass drops a cat whose inputs are all empty."""

        class Cat(torch.nn.Module):
            def __init__(self, dim: int):
                super().__init__()
                self.dim = dim

            def forward(self, tensors):
                return torch.cat(tensors, self.dim)

        shapes, dim, dtype, _max = [(1, 0, 3), (2, 0, 3)], 0, torch.float32, 127

        in_tensors = [(torch.rand(shape) * _max).to(dtype=dtype) for shape in shapes]

        model = Cat(dim)
        graph_module = (
            export_to_edge(model, (in_tensors,)).exported_program().graph_module
        )
        new_graph_module = RemoveZeroSizedCatArgsPass()(graph_module).graph_module
        new_graph_module.graph.eliminate_dead_code()
        # Count the edge-dialect cat; the `.out` variant is not what
        # export_to_edge graphs are checked for elsewhere in this file.
        self.assertEqual(
            count_node(new_graph_module, exir_ops.edge.aten.cat.default), 0
        )

    def test_remove_clone(self):
        """RemoveCloneOpPass leaves no clone nodes in the graph."""

        class Clone(torch.nn.Module):
            def forward(self, x, y):
                t1 = x.clone()
                t2 = y.clone()
                return t1 + t2

        x = torch.ones(3, 5)
        y = torch.ones(3, 5)
        graph_module = export_to_edge(Clone(), (x, y)).exported_program().graph_module
        new_graph_module = RemoveCloneOpPass()(graph_module).graph_module
        new_graph_module.graph.eliminate_dead_code()
        # Assert that t1 and t2 are optimized away (edge-dialect clone op).
        self.assertEqual(
            count_node(new_graph_module, exir_ops.edge.aten.clone.default), 0
        )

    def test_remove_contiguous(self):
        """RemoveContiguousOpPass leaves no contiguous nodes in the graph."""

        class Contiguous(torch.nn.Module):
            def forward(self, x, y):
                t1 = x.contiguous()
                t2 = y.contiguous()
                return t1 + t2

        x = torch.ones(3, 5)
        y = torch.ones(3, 5)
        graph_module = (
            export_to_edge(Contiguous(), (x, y)).exported_program().graph_module
        )
        new_graph_module = RemoveContiguousOpPass()(graph_module).graph_module
        new_graph_module.graph.eliminate_dead_code()
        # Assert that t1 and t2 are optimized away (edge-dialect contiguous op).
        self.assertEqual(
            count_node(new_graph_module, exir_ops.edge.aten.contiguous.default), 0
        )

    @parameterized.expand(
        [
            [(3, 5), [3, 5]],
            [(1,), [-1]],
        ]
    )
    @torch.no_grad()
    def test_remove_nop_view(self, shape, new_shape):
        """RemoveNopSliceOrViewOpPass removes a view that keeps the input shape."""

        class View(torch.nn.Module):
            def __init__(self, new_shape):
                super().__init__()
                self.new_shape = new_shape

            def forward(self, x: torch.Tensor):
                return x.view(self.new_shape)

        model = View(new_shape)
        x = torch.randn(shape)
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
        p = RemoveNopSliceOrViewOpPass()
        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
        graph_after_passes.graph.eliminate_dead_code()
        # Assert that view op was removed
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 0
        )

    def test_remove_nop_slice(self):
        """RemoveNopSliceOrViewOpPass removes a full-range slice along dim 0."""

        class Slice(torch.nn.Module):
            def forward(self, x):
                return torch.slice_copy(x, dim=0, start=0, step=1)

        x = torch.ones(3, 5)
        model = Slice()
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
        p = RemoveNopSliceOrViewOpPass()
        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
        graph_after_passes.graph.eliminate_dead_code()
        # Assert that slice op was removed
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor), 0
        )

    def test_remove_nop_select(self):
        """RemoveNopSelectOpPass removes a select on a size-1 leading dim.

        Four consumers of the select result are exercised: a view back to
        the original shape, and an add, mul and div against a tensor of the
        original shape.
        """

        class SelectFeasible1(torch.nn.Module):
            def forward(self, x):
                y = x.select(0, 0)
                z = y.view([1, 5, 6])
                return z

        class SelectFeasible2(torch.nn.Module):
            def forward(self, x, y):
                x = x.select(0, 0)
                z = x + y
                return z

        class SelectFeasible3(torch.nn.Module):
            def forward(self, x, y):
                x = x.select(0, 0)
                z = x * y
                return z

        class SelectFeasible4(torch.nn.Module):
            def forward(self, x, y):
                x = x.select(0, 0)
                z = x / y
                return z

        cases = [
            (SelectFeasible1(), (torch.ones(1, 5, 6),)),
            (SelectFeasible2(), (torch.ones(1, 5, 6), torch.ones(1, 5, 6))),
            (SelectFeasible3(), (torch.ones(1, 5, 6), torch.ones(1, 5, 6))),
            (SelectFeasible4(), (torch.ones(1, 5, 6), torch.ones(1, 5, 6))),
        ]

        for model, inputs in cases:
            with self.subTest(model=type(model).__name__):
                graph_module = (
                    export_to_edge(model, inputs).exported_program().graph_module
                )
                # Sanity check: the select is present before the pass runs.
                self.assertEqual(
                    count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1
                )
                graph_module = RemoveNopSelectOpPass()(graph_module).graph_module
                # Assert that select op was removed
                self.assertEqual(
                    count_node(graph_module, exir_ops.edge.aten.select_copy.int), 0
                )

    def test_remove_nop_quant_dequant(self):
        """``export_to_cadence`` strips redundant quantize/dequantize pairs."""

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(6, 12, bias=False)

            def forward(self, x):
                x = self.linear(x)
                return x

        inp = torch.randn(2, 8, 1, 6)

        # Run the standard quant/convert steps, but without fusing
        # this leaves two redundant quant/dequant pairs to test with
        quantizer = CadenceQuantizer()
        model_exp = export_for_training(M(), (inp,)).module()
        prepared_model = prepare_pt2e(model_exp, quantizer)
        prepared_model(inp)
        converted_model = convert_pt2e(prepared_model)

        graph_module = (
            compiler.export_to_cadence(
                converted_model,
                (inp,),
            )
            .exported_program()
            .graph_module
        )

        # Expect all quantize ops to be removed by the pass
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.cadence.quantize_per_tensor.default),
            0,
        )

        # Expect 1 dequantize op for the weights
        self.assertEqual(
            count_node(
                graph_module, exir_ops.edge.cadence.dequantize_per_tensor.default
            ),
            1,
        )

    def test_remove_nop_aten_linalg_vector_norm(self):
        """RemoveNopLinalgVectorNormOpPass removes a norm over size-1 dims."""

        class LinalgVectorNorm(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return torch.linalg.vector_norm(x, 2, [0, 1], True)

        model = LinalgVectorNorm()
        x = torch.randn([1, 1, 128])
        inputs = (x,)

        graph_module = (
            compiler.export_to_edge(
                model,
                inputs,
            )
            .exported_program()
            .graph_module
        )

        graph_module = none_throws(
            RemoveNopLinalgVectorNormOpPass()(graph_module)
        ).graph_module

        # Expect the linalg_vector_norm op to be removed by the pass
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.linalg_vector_norm.default)
            + count_node(
                graph_module, exir_ops.edge.cadence.linalg_vector_norm.default
            ),
            0,
        )

    def test_remove_permutes_around_elemwise_ops_add(self) -> None:
        """Both permutes surrounding a single add are removed."""

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(8, 8, 1, bias=False)

            def forward(self, x):
                x = self.conv(x)
                x = torch.permute(x, [0, 3, 1, 2])
                x = torch.add(x, x)
                x = torch.permute(x, [0, 2, 3, 1])
                x = self.conv(x)
                return x

        inputs = (torch.randn(1, 8, 4, 4),)
        graph_module = export_to_edge(M(), inputs).exported_program().graph_module
        p = RemovePermutesAroundElementwiseOps()
        graph_module = cast(PassResult, p(graph_module)).graph_module

        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 0
        )

    def test_remove_permutes_around_elemwise_ops_add_mean(self) -> None:
        """Permutes around add + mean are removed and the mean dims are remapped."""

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv2d = nn.Conv2d(8, 8, 1)

            def forward(self, x, y):
                x = self.conv2d(x)
                y = self.conv2d(y)
                x = torch.permute(x, [0, 3, 1, 2])
                y = torch.permute(y, [0, 3, 1, 2])
                z = torch.add(x, y)
                z = torch.mean(z, dim=[-1, -3], keepdim=True)
                z = torch.permute(z, [0, 2, 3, 1])
                z = self.conv2d(z)
                return z

        inputs = (torch.randn(1, 8, 4, 4), torch.randn(1, 8, 4, 4))
        graph_module = export_to_edge(M(), inputs).exported_program().graph_module
        p = RemovePermutesAroundElementwiseOps()
        graph_module = cast(PassResult, p(graph_module)).graph_module

        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 0
        )

        # verify that mean was updated correctly
        mean = self._first_node(graph_module, exir_ops.edge.aten.mean.dim)
        self.assertEqual(mean.args[1], [2, 3])

    def test_remove_permutes_around_elemwise_ops_mul(self) -> None:
        """Permutes around a dequantize -> mul -> quantize chain are removed."""

        class M(torch.nn.Module):
            def forward(self, x, y):
                x = torch.slice_copy(x, 0, 0, 1)
                x = torch.permute(x, [0, 3, 1, 2])
                y = torch.permute(y, [0, 3, 1, 2])
                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                    x, 1.5, 0, 0, 255, torch.uint8
                )
                z = x * y
                z = torch.ops.quantized_decomposed.quantize_per_tensor(
                    z, 2.5, 0, 0, 255, torch.uint8
                )
                z = torch.permute(z, [0, 2, 3, 1])
                z = torch.unsqueeze_copy(z, 0)
                return z

        inputs = (torch.randn(2, 4, 4, 8), torch.randn(2, 4, 4, 8))
        graph_module = export_to_edge(M(), inputs).exported_program().graph_module

        p = RemovePermutesAroundElementwiseOps()
        graph_module = cast(PassResult, p(graph_module)).graph_module

        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 0
        )

    def test_remove_permutes_around_elemwise_ops_double_permutes(self) -> None:
        """Only the inner permutes are removed when permutes are stacked."""

        class M(torch.nn.Module):
            def forward(self, x, y):
                x = torch.slice_copy(x, 0, 0, 1)
                x = torch.permute(x, [0, 3, 1, 2])
                x = torch.permute(x, [0, 3, 1, 2])
                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                    x, 1.5, 0, 0, 255, torch.uint8
                )
                y = torch.permute(y, [0, 3, 1, 2])
                y = torch.ops.quantized_decomposed.dequantize_per_tensor(
                    y, 1.5, 0, 0, 255, torch.uint8
                )
                z = torch.cat((x, y), 1)
                z = torch.ops.quantized_decomposed.quantize_per_tensor(
                    z, 2.5, 0, 0, 255, torch.uint8
                )
                z = torch.permute(z, [0, 2, 3, 1])
                z = torch.permute(z, [0, 2, 3, 1])
                z = torch.unsqueeze_copy(z, 0)
                return z

        inputs = (torch.randn(2, 4, 4, 8), torch.randn(1, 8, 4, 4))
        graph_module = export_to_edge(M(), inputs).exported_program().graph_module
        p = RemovePermutesAroundElementwiseOps()
        graph_module = cast(PassResult, p(graph_module)).graph_module

        # Expect 2 permutes to remain, one on input x and one on output z
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 2
        )

        # verify that cat was updated correctly
        cat = self._first_node(graph_module, exir_ops.edge.aten.cat.default)
        self.assertEqual(cat.args[1], 3)

    def test_remove_permutes_around_elemwise_ops_noop(self) -> None:
        """Permutes that do not match the expected pattern are left in place."""

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(8, 8, 1, bias=False)

            def forward(self, x):
                x = self.conv(x)
                x = torch.permute(x, [0, 2, 3, 1])
                x = torch.add(x, x)
                x = torch.permute(x, [0, 3, 1, 2])
                x = self.conv(x)
                return x

        inputs = (torch.randn(1, 8, 4, 4),)
        graph_module = export_to_edge(M(), inputs).exported_program().graph_module
        p = RemovePermutesAroundElementwiseOps()
        graph_module = cast(PassResult, p(graph_module)).graph_module

        # Ensure no permutes were removed, since the dimensions don't fit the expected pattern
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 2
        )