# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.


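# Tests for the remove_ops passes, which strip no-op and redundant operators
# (to, alias_copy, clone, permute, etc.) from graphs exported for the Cadence
# backend.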
import unittest
from typing import cast, Tuple

import executorch.backends.cadence.aot.ops_registrations  # noqa
import torch
import torch.nn as nn
import torch.nn.functional as F
from executorch.backends.cadence.aot import compiler
from executorch.backends.cadence.aot.compiler import export_to_edge

from executorch.backends.cadence.aot.pass_utils import count_node
from executorch.backends.cadence.aot.quantizer.quantizer import CadenceQuantizer
from executorch.backends.cadence.aot.remove_ops import (
    RemoveAliasCopyOpPass,
    RemoveCloneOpPass,
    RemoveContiguousOpPass,
    RemoveDetachCopyPass,
    RemoveNopAddOpPass,
    RemoveNopExpandOpPass,
    RemoveNopLinalgVectorNormOpPass,
    RemoveNopMulOpPass,
    RemoveNopSelectOpPass,
    RemoveNopSliceOrViewOpPass,
    RemovePermutesAroundElementwiseOps,
    RemoveToOpsPass,
    RemoveZeroSizedCatArgsPass,
    RemoveZeroSizedConstantPadNd,
)
from executorch.exir.dialects._ops import ops as exir_ops
from parameterized.parameterized import parameterized
from pyre_extensions import none_throws

from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e

from torch.export import export_for_training
from torch.fx.passes.infra.pass_base import PassResult


class TestRemoveOpsPasses(unittest.TestCase):
    @parameterized.expand(
        [
            [(1, 2, 3)],
        ]
    )
    @torch.no_grad()
    def test_remove_to_ops(self, shape: Tuple[int]):
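        # The `to` call keeps the same dtype as its float32 input, so
        # RemoveToOpsPass should eliminate it; neither the dtype nor the
        # dtype_layout variant should remain in the graph.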
        class M(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return exir_ops.edge.aten.to(x, dtype=torch.float32)

        model = M()
        x = torch.randn(shape)
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
        p = RemoveToOpsPass()

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.to.dtype),
            0,
        )

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.to.dtype_layout),
            0,
        )

    @parameterized.expand(
        [
            [(7, 6, 5)],
            [(7, 6)],
            [(7,)],
        ]
    )
    @torch.no_grad()
    def test_remove_nop_add_op_pass(self, shape: Tuple[int]):
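        # Adding a zero-filled tensor is a no-op, so RemoveNopAddOpPass should
        # drop the add regardless of which operand is the zero tensor.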
        class FullX(torch.nn.Module):
            def forward(self, t: torch.Tensor):
                return torch.add(torch.full(shape, 0), t)

        class FullY(torch.nn.Module):
            def forward(self, t: torch.Tensor):
                return torch.add(t, torch.full(shape, 0))

        model = FullX()
        t = torch.full(shape, 3)
        graph_module = export_to_edge(model, (t,)).exported_program().graph_module

        p = RemoveNopAddOpPass()

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
        graph_module.print_readable()
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor),
            0,
        )

        model = FullY()
        graph_module = export_to_edge(model, (t,)).exported_program().graph_module

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor),
            0,
        )

    @parameterized.expand(
        [
            [(7, 6, 5)],
            [(7, 6)],
            [(7,)],
        ]
    )
    @torch.no_grad()
    def test_remove_nop_mul_op_pass(self, shape: Tuple[int]):
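        # One operand is a zero-filled tensor, so RemoveNopMulOpPass should
        # remove the mul node, whichever side the zero tensor is on.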
        class FullX(torch.nn.Module):
            def forward(self, t: torch.Tensor):
                return torch.mul(torch.full(shape, 0), t)

        class FullY(torch.nn.Module):
            def forward(self, t: torch.Tensor):
                return torch.mul(t, torch.full(shape, 0))

        model = FullX()
        t = torch.full(shape, 3)
        graph_module = export_to_edge(model, (t,)).exported_program().graph_module

        p = RemoveNopMulOpPass()

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
        graph_module.print_readable()
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor),
            0,
        )

        model = FullY()
        graph_module = export_to_edge(model, (t,)).exported_program().graph_module

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor),
            0,
        )

    @parameterized.expand(
        [
            [(1, 2, 3)],
        ]
    )
    @torch.no_grad()
    def test_remove_alias_copy(self, shape: Tuple[int]):
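        # alias_copy does not change the tensor's contents, so
        # RemoveAliasCopyOpPass should remove it from the graph.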
        class M(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return exir_ops.edge.aten.alias_copy(x)

        model = M()
        x = torch.randn(shape)
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module

        p = RemoveAliasCopyOpPass()

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.alias_copy.default),
            0,
        )

    @parameterized.expand(
        [
            [(1, 2, 3)],
        ]
    )
    @torch.no_grad()
    def test_remove_detach_copy(self, shape: Tuple[int]):
        # aten::detach is converted to aten::alias_copy after functionalization & decomposition.
        class M(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return exir_ops.edge.aten.detach_copy(x)

        model = M()
        x = torch.randn(shape)
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module

        p = RemoveDetachCopyPass()

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.detach_copy.default),
            0,
        )

    @parameterized.expand(
        [
            [(1, 2, 3), (0, 0)],
        ]
    )
    @torch.no_grad()
    def test_remove_zero_sized_constant_pad_nd(
        self, shape: Tuple[int], padding: Tuple[int]
    ):
        # F.pad is converted to aten::constant_pad_nd after functionalization & decomposition.
        class Padding(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.padding = padding

            def forward(self, x: torch.Tensor):
                return F.pad(x, self.padding)

        model = Padding()
        x = torch.randn(shape)
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module

        p = RemoveZeroSizedConstantPadNd()

        graph_after_passes = cast(PassResult, p(graph_module)).graph_module

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.constant_pad_nd.default),
            0,
        )

    def test_remove_expand(self):
        class Expand(torch.nn.Module):
            def forward(self, x):
                return torch.ops.aten.expand_copy(x, [2, 3, 5])

        x = torch.ones(2, 3, 5)
        p = RemoveNopExpandOpPass()
        graph_module = export_to_edge(Expand(), (x,)).exported_program().graph_module
        graph_module = p(graph_module).graph_module
        # Assert that expand op is optimized away, since it is a nop
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.expand_copy.default), 0
        )

    def test_remove_zero_arg_cat(self):
        class Cat(torch.nn.Module):
            def forward(self, x, y):
                return torch.ops.aten.cat((x, y), 0)

        x = torch.ones(1, 0, 3, 5)
        y = torch.ones(2, 0, 3, 5)
        graph_module = (
            compiler.export_to_cadence(Cat(), (x, y)).exported_program().graph_module
        )
        # Assert that cat op is optimized away, since it concatenates
        # two zero-sized tensors
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 0)

    def test_remove_single_arg_cat(self):
        class Cat(torch.nn.Module):
            def forward(self, x, y):
                z = torch.ones(0, 5)
                # z is an empty tensor, and concatenation of x with z will
                # be x. So we can safely eliminate the following cat op.
                x1 = torch.ops.aten.cat((x, z))
                x2 = torch.add(x1, 2.4, 3.1)
                y1 = torch.add(y, 1, 2)
                return torch.add(x2, y1)

        x = torch.ones(3, 5)
        y = torch.ones(3, 5)
        graph_module = export_to_edge(Cat(), (x, y)).exported_program().graph_module
        new_graph_module = RemoveZeroSizedCatArgsPass()(graph_module).graph_module
        new_graph_module.graph.eliminate_dead_code()
        # Assert that x1 is optimized away
        self.assertEqual(count_node(new_graph_module, torch.ops.aten.cat.out), 0)

    def test_remove_zero_sized_cat(self):
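        # Every input to the cat is zero-sized, so RemoveZeroSizedCatArgsPass
        # should eliminate the cat op entirely.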
        class Cat(torch.nn.Module):
            def __init__(self, dim: int):
                super().__init__()
                self.dim = dim

            def forward(self, tensors):
                return torch.cat(tensors, self.dim)

        shapes, dim, dtype, _max = [(1, 0, 3), (2, 0, 3)], 0, torch.float32, 127

        in_tensors = [(torch.rand(shape) * _max).to(dtype=dtype) for shape in shapes]

        model = Cat(dim)
        graph_module = (
            export_to_edge(model, (in_tensors,)).exported_program().graph_module
        )
        new_graph_module = RemoveZeroSizedCatArgsPass()(graph_module).graph_module
        new_graph_module.graph.eliminate_dead_code()
        self.assertEqual(count_node(new_graph_module, torch.ops.aten.cat.out), 0)

    def test_remove_clone(self):
        class Clone(torch.nn.Module):
            def forward(self, x, y):
                t1 = x.clone()
                t2 = y.clone()
                return t1 + t2

        x = torch.ones(3, 5)
        y = torch.ones(3, 5)
        graph_module = export_to_edge(Clone(), (x, y)).exported_program().graph_module
        new_graph_module = RemoveCloneOpPass()(graph_module).graph_module
        new_graph_module.graph.eliminate_dead_code()
        # Assert that t1 and t2 are optimized away
        self.assertEqual(count_node(new_graph_module, torch.ops.aten.clone.out), 0)

    def test_remove_contiguous(self):
        class Contiguous(torch.nn.Module):
            def forward(self, x, y):
                t1 = x.contiguous()
                t2 = y.contiguous()
                return t1 + t2

        x = torch.ones(3, 5)
        y = torch.ones(3, 5)
        graph_module = (
            export_to_edge(Contiguous(), (x, y)).exported_program().graph_module
        )
        new_graph_module = RemoveContiguousOpPass()(graph_module).graph_module
        new_graph_module.graph.eliminate_dead_code()
        # Assert that t1 and t2 are optimized away
        self.assertEqual(count_node(new_graph_module, torch.ops.aten.contiguous.out), 0)

    @parameterized.expand(
        [
            [(3, 5), [3, 5]],
            [(1,), [-1]],
        ]
    )
    @torch.no_grad()
    def test_remove_nop_view(self, shape, new_shape):
        class View(torch.nn.Module):
            def __init__(self, new_shape):
                super().__init__()
                self.new_shape = new_shape

            def forward(self, x: torch.Tensor):
                return x.view(self.new_shape)

        model = View(new_shape)
        x = torch.randn(shape)
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
        p = RemoveNopSliceOrViewOpPass()
        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
        graph_after_passes.graph.eliminate_dead_code()
        # Assert that view op was removed
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 0
        )

    def test_remove_nop_slice(self):
        class Slice(torch.nn.Module):
            def forward(self, x):
                return torch.slice_copy(x, dim=0, start=0, step=1)

        x = torch.ones(3, 5)
        model = Slice()
        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
        p = RemoveNopSliceOrViewOpPass()
        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
        graph_after_passes.graph.eliminate_dead_code()
        # Assert that slice op was removed
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor), 0
        )

    def test_remove_nop_select(self):
        class SelectFeasible1(torch.nn.Module):
            def forward(self, x):
                y = x.select(0, 0)
                z = y.view([1, 5, 6])
                return z

        x = torch.ones(1, 5, 6)
        graph_module = (
            export_to_edge(SelectFeasible1(), (x,)).exported_program().graph_module
        )
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1
        )
        graph_module = RemoveNopSelectOpPass()(graph_module).graph_module
        # Assert that select op was removed
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.select_copy.int), 0
        )

        class SelectFeasible2(torch.nn.Module):
            def forward(self, x, y):
                x = x.select(0, 0)
                z = x + y
                return z

        x = torch.ones(1, 5, 6)
        y = torch.ones(1, 5, 6)
        graph_module = (
            export_to_edge(SelectFeasible2(), (x, y)).exported_program().graph_module
        )
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1
        )
        graph_module = RemoveNopSelectOpPass()(graph_module).graph_module
        # Assert that select op was removed
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.select_copy.int), 0
        )

        class SelectFeasible3(torch.nn.Module):
            def forward(self, x, y):
                x = x.select(0, 0)
                z = x * y
                return z

        x = torch.ones(1, 5, 6)
        y = torch.ones(1, 5, 6)
        graph_module = (
            export_to_edge(SelectFeasible3(), (x, y)).exported_program().graph_module
        )
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1
        )
        graph_module = RemoveNopSelectOpPass()(graph_module).graph_module
        # Assert that select op was removed
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.select_copy.int), 0
        )

        class SelectFeasible4(torch.nn.Module):
            def forward(self, x, y):
                x = x.select(0, 0)
                z = x / y
                return z

        x = torch.ones(1, 5, 6)
        y = torch.ones(1, 5, 6)
        graph_module = (
            export_to_edge(SelectFeasible4(), (x, y)).exported_program().graph_module
        )
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1
        )
        graph_module = RemoveNopSelectOpPass()(graph_module).graph_module
        # Assert that select op was removed
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.select_copy.int), 0
        )

    def test_remove_nop_quant_dequant(self):
        class M(torch.nn.Module):
            def __init__(self):
                super(M, self).__init__()
                self.linear = torch.nn.Linear(6, 12, bias=False)

            def forward(self, x):
                x = self.linear(x)
                return x

        inp = torch.randn(2, 8, 1, 6)

        # Run the standard quant/convert steps, but without fusing.
        # This leaves two redundant quant/dequant pairs to test with.
        quantizer = CadenceQuantizer()
        model_exp = export_for_training(M(), (inp,)).module()
        prepared_model = prepare_pt2e(model_exp, quantizer)
        prepared_model(inp)
        converted_model = convert_pt2e(prepared_model)

        graph_module = (
            compiler.export_to_cadence(
                converted_model,
                (inp,),
            )
            .exported_program()
            .graph_module
        )

        # Expect all quantize ops to be removed by the pass
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.cadence.quantize_per_tensor.default),
            0,
        )

        # Expect 1 dequantize op for the weights
        self.assertEqual(
            count_node(
                graph_module, exir_ops.edge.cadence.dequantize_per_tensor.default
            ),
            1,
        )

    def test_remove_nop_aten_linalg_vector_norm(self):
        class LinalgVectorNorm(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return torch.linalg.vector_norm(x, 2, [0, 1], True)

        model = LinalgVectorNorm()
        x = torch.randn([1, 1, 128])
        inputs = (x,)

        graph_module = (
            compiler.export_to_edge(
                model,
                inputs,
            )
            .exported_program()
            .graph_module
        )

        graph_module = none_throws(
            RemoveNopLinalgVectorNormOpPass()(graph_module)
        ).graph_module

        # Expect the linalg_vector_norm op to be removed by the pass
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.linalg_vector_norm.default)
            + count_node(
                graph_module, exir_ops.edge.cadence.linalg_vector_norm.default
            ),
            0,
        )

    def test_remove_permutes_around_elemwise_ops_add(self) -> None:
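        # The two permutes around the elementwise add cancel each other out, so
        # RemovePermutesAroundElementwiseOps should remove both of them.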
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(8, 8, 1, bias=False)

            def forward(self, x):
                x = self.conv(x)
                x = torch.permute(x, [0, 3, 1, 2])
                x = torch.add(x, x)
                x = torch.permute(x, [0, 2, 3, 1])
                x = self.conv(x)
                return x

        inputs = (torch.randn(1, 8, 4, 4),)
        graph_module = export_to_edge(M(), inputs).exported_program().graph_module
        p = RemovePermutesAroundElementwiseOps()
        graph_module = cast(PassResult, p(graph_module)).graph_module

        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 0
        )

    def test_remove_permutes_around_elemwise_ops_add_mean(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv2d = nn.Conv2d(8, 8, 1)

            def forward(self, x, y):
                x = self.conv2d(x)
                y = self.conv2d(y)
                x = torch.permute(x, [0, 3, 1, 2])
                y = torch.permute(y, [0, 3, 1, 2])
                z = torch.add(x, y)
                z = torch.mean(z, dim=[-1, -3], keepdim=True)
                z = torch.permute(z, [0, 2, 3, 1])
                z = self.conv2d(z)
                return z

        inputs = (torch.randn(1, 8, 4, 4), torch.randn(1, 8, 4, 4))
        graph_module = export_to_edge(M(), inputs).exported_program().graph_module
        p = RemovePermutesAroundElementwiseOps()
        graph_module = cast(PassResult, p(graph_module)).graph_module

        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 0
        )

        # verify that mean was updated correctly
        mean = [
            n
            for n in graph_module.graph.nodes
            if n.target == exir_ops.edge.aten.mean.dim
        ][0]
        self.assertEqual(mean.args[1], [2, 3])

    def test_remove_permutes_around_elemwise_ops_mul(self) -> None:
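        # The dequantize -> mul -> quantize chain is elementwise, so the
        # permutes surrounding it should all be removed by the pass.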
        class M(torch.nn.Module):
            def forward(self, x, y):
                x = torch.slice_copy(x, 0, 0, 1)
                x = torch.permute(x, [0, 3, 1, 2])
                y = torch.permute(y, [0, 3, 1, 2])
                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                    x, 1.5, 0, 0, 255, torch.uint8
                )
                z = x * y
                z = torch.ops.quantized_decomposed.quantize_per_tensor(
                    z, 2.5, 0, 0, 255, torch.uint8
                )
                z = torch.permute(z, [0, 2, 3, 1])
                z = torch.unsqueeze_copy(z, 0)
                return z

        inputs = (torch.randn(2, 4, 4, 8), torch.randn(2, 4, 4, 8))
        graph_module = export_to_edge(M(), inputs).exported_program().graph_module

        p = RemovePermutesAroundElementwiseOps()
        graph_module = cast(PassResult, p(graph_module)).graph_module

        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 0
        )

    def test_remove_permutes_around_elemwise_ops_double_permutes(self) -> None:
        class M(torch.nn.Module):
            def forward(self, x, y):
                x = torch.slice_copy(x, 0, 0, 1)
                x = torch.permute(x, [0, 3, 1, 2])
                x = torch.permute(x, [0, 3, 1, 2])
                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                    x, 1.5, 0, 0, 255, torch.uint8
                )
                y = torch.permute(y, [0, 3, 1, 2])
                y = torch.ops.quantized_decomposed.dequantize_per_tensor(
                    y, 1.5, 0, 0, 255, torch.uint8
                )
                z = torch.cat((x, y), 1)
                z = torch.ops.quantized_decomposed.quantize_per_tensor(
                    z, 2.5, 0, 0, 255, torch.uint8
                )
                z = torch.permute(z, [0, 2, 3, 1])
                z = torch.permute(z, [0, 2, 3, 1])
                z = torch.unsqueeze_copy(z, 0)
                return z

        inputs = (torch.randn(2, 4, 4, 8), torch.randn(1, 8, 4, 4))
        graph_module = export_to_edge(M(), inputs).exported_program().graph_module
        p = RemovePermutesAroundElementwiseOps()
        graph_module = cast(PassResult, p(graph_module)).graph_module

        # Expect 2 permutes to remain, one on input x and one on output z
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 2
        )

        # verify that cat was updated correctly
        cat = [
            n
            for n in graph_module.graph.nodes
            if n.target == exir_ops.edge.aten.cat.default
        ][0]
        self.assertEqual(cat.args[1], 3)

    def test_remove_permutes_around_elemwise_ops_noop(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(8, 8, 1, bias=False)

            def forward(self, x):
                x = self.conv(x)
                x = torch.permute(x, [0, 2, 3, 1])
                x = torch.add(x, x)
                x = torch.permute(x, [0, 3, 1, 2])
                x = self.conv(x)
                return x

        inputs = (torch.randn(1, 8, 4, 4),)
        graph_module = export_to_edge(M(), inputs).exported_program().graph_module
        p = RemovePermutesAroundElementwiseOps()
        graph_module = cast(PassResult, p(graph_module)).graph_module

        # Ensure no permutes were removed, since the dimensions don't fit the expected pattern
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 2
        )