1# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
2
3import unittest
4from typing import Any, Callable, cast, List, Optional, Sequence, Tuple, Union
5
6import torch
7import torch.nn.functional as F
8from executorch.backends.cadence.aot import compiler
9from executorch.backends.cadence.aot.compiler import export_to_edge, quantize_pt2
10from executorch.backends.cadence.aot.graph_builder import single_op_builder
11from executorch.backends.cadence.aot.pass_utils import count_node
12from executorch.backends.cadence.aot.replace_ops import (
13    ForceChannelLastForConvPass,
14    MakeSliceAndCatDimOutermostPass,
15    ReplaceAddMMWithLinearPass,
16    ReplaceAtenConvolutionWithJarvisConvolutionPass,
17    ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass,
18    ReplaceConstantPadNdWithSlicePass,
19    ReplaceConvolutionOptionalArgsWithConcreteArgsPass,
20    ReplaceConvWithIm2RowAndLinear,
21    ReplaceFunctionallyEquivalentOpTargets,
22    ReplaceIm2RowWithViewPass,
23    ReplaceLinearWithFullyConnectedOpPass,
24    ReplaceMMWithAddMMPass,
25    ReplaceNopTransposeOrPermuteWithViewPass,
26    ReplacePadWithCatPass,
27    ReplacePermuteWithTransposePass,
28    ReplaceRepeatWithCatPass,
29    ReplaceScalarTensorWithFullPass,
30    ReplaceScalarWithTensorArgPass,
31    ReplaceSelectWithViewOpPass,
32    ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
33    ReplaceSqueezeAndUnsqueezeWithViewPass,
34    ReplaceTCopyWithTransposePass,
35    ReplaceTransposedConvWithLinearPass,
36    ReplaceTrivialConvWithLinear,
37)
38from executorch.exir.dialects._ops import ops as exir_ops
39from executorch.exir.pass_base import ExportPass
40from executorch.exir.passes import dead_code_elimination_pass
41
42from parameterized.parameterized import parameterized
43from torch._ops import OpOverload
44from torch.fx.passes.infra.pass_base import PassResult
45
46
47class TestReplaceOpsPasses(unittest.TestCase):
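    """Tests for the Cadence AOT replace_ops passes.

    Each test exports a small module to edge IR, runs one (or a chain of)
    replacement passes, and counts nodes to verify the expected rewrite.
    """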
48    def assertTargetCountEqual(
49        self,
50        graph_module: torch.fx.GraphModule,
51        target: Union[Callable[..., Any], str],
52        expected_count: int,
53    ):
54        """Helper function to check the number of nodes with a given target."""
55        actual_count = count_node(graph_module, target)
56        self.assertEqual(
57            actual_count,
58            expected_count,
59            f"{target} count mismatch for graph {graph_module}",
60        )
61
62    def assertTargetCountsEqual(
63        self,
64        graph_module: torch.fx.GraphModule,
65        targets_and_counts: List[Tuple[Union[Callable[..., Any], str], int]],
66    ):
67        """Helper function to check the number of nodes of all types for a given target."""
68        for target, expected_count in targets_and_counts:
69            self.assertTargetCountEqual(graph_module, target, expected_count)
70
71    @parameterized.expand(
72        [
73            [(3, 5), (0, 0)],
74            [
75                (20, 1, 80),
76                (0, 0),
77            ],
78        ]
79    )
80    @torch.no_grad()
81    def test_replace_constant_pad_nd_with_slice(
82        self, shape: Tuple[int], padding: Tuple[int]
83    ):
84        # F.pad is converted to aten::constant_pad_nd after functionalization & decomposition.
85        class Padding(torch.nn.Module):
86            def __init__(self):
87                super().__init__()
88                self.padding = padding
89
90            def forward(self, x: torch.Tensor):
91                return F.pad(x, self.padding)
92
93        model = Padding()
94        x = torch.randn(shape)
95        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
96
97        p = ReplaceConstantPadNdWithSlicePass()
98
99        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
100        self.assertEqual(
101            count_node(graph_after_passes, exir_ops.edge.aten.slice.Tensor),
102            1,
103        )
104
105        self.assertEqual(
106            count_node(graph_after_passes, exir_ops.edge.aten.constant_pad_nd.default),
107            0,
108        )
109
110    @parameterized.expand(
111        [
112            [(7, 5, 6), 1.23],
113            [(7, 5), 2],
114        ]
115    )
116    @torch.no_grad()
117    def test_add_replace_scalar_with_tensor_arg(self, shape: Tuple[int], other: float):
118        class Add(torch.nn.Module):
119            def forward(self, x):
120                return torch.ops.aten.add.Scalar(x, other)
121
122        model = Add()
123        x = torch.randn(shape)
124        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
125
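        # The pass should promote the scalar `other` into a tensor argument,
        # turning add.Scalar into add.Tensor.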
126        p = ReplaceScalarWithTensorArgPass()
127
128        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
129        self.assertEqual(
130            count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor),
131            1,
132        )
133
134        self.assertEqual(
135            count_node(graph_after_passes, exir_ops.edge.aten.add.Scalar),
136            0,
137        )
138
139    @parameterized.expand(
140        [
141            [(7, 5, 6), 1.23],
142            [(7, 5), 2],
143            [(10), 42949],
144        ]
145    )
146    @torch.no_grad()
147    def test_sub_replace_scalar_with_tensor_arg(self, shape: Tuple[int], other: float):
148        class Sub(torch.nn.Module):
149            def forward(self, x):
150                return torch.ops.aten.sub.Scalar(x, other)
151
152        model = Sub()
153        x = torch.randn(shape)
154        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
155
156        p = ReplaceScalarWithTensorArgPass()
157
158        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
159        self.assertEqual(
160            count_node(graph_after_passes, exir_ops.edge.aten.sub.Tensor),
161            1,
162        )
163
164        self.assertEqual(
165            count_node(graph_after_passes, exir_ops.edge.aten.sub.Scalar),
166            0,
167        )
168
169    @parameterized.expand(
170        [
171            [(7, 5, 6), 1.23],
172            [(7, 5), 2],
173            [(513), 3],
174        ]
175    )
176    @torch.no_grad()
177    def test_mul_replace_scalar_with_tensor_arg(self, shape: Tuple[int], other: float):
178        class Mul(torch.nn.Module):
179            def forward(self, x):
180                return torch.ops.aten.mul.Scalar(x, other)
181
182        model = Mul()
183        x = torch.randn(shape)
184        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
185
186        p = ReplaceScalarWithTensorArgPass()
187
188        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
189        self.assertEqual(
190            count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor),
191            1,
192        )
193
194        self.assertEqual(
195            count_node(graph_after_passes, exir_ops.edge.aten.mul.Scalar),
196            0,
197        )
198
199    @parameterized.expand(
200        [
201            [(7, 5, 6), 1.23],
202            [(7, 5), 2],
203        ]
204    )
205    @torch.no_grad()
206    def test_div_replace_scalar_with_tensor_arg(
207        self,
208        shape: Tuple[int],
209        other: float,
210    ):
211        class Div(torch.nn.Module):
212            def forward(self, x):
213                return torch.ops.aten.div.Scalar(x, other)
214
215        model = Div()
216        x = torch.randn(shape)
217        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
218
219        p = ReplaceScalarWithTensorArgPass()
220
221        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
222        self.assertEqual(
223            count_node(graph_after_passes, exir_ops.edge.aten.div.Tensor),
224            1,
225        )
226
227        self.assertEqual(
228            count_node(graph_after_passes, exir_ops.edge.aten.div.Scalar),
229            0,
230        )
231
232    @parameterized.expand(
233        [
234            [(2, 3, 5, 6)],
235            [(7, 6, 5)],
236            [(4, 4)],
237            [(316)],
238        ]
239    )
240    @torch.no_grad()
241    def test_replace_functionally_equivalent_op_targets_relu(self, shape: Tuple[int]):
242        model = torch.nn.ReLU()
243        x = torch.randn(shape)
244        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
245        p = ReplaceFunctionallyEquivalentOpTargets()
246
247        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
248        self.assertEqual(
249            count_node(graph_after_passes, exir_ops.edge.aten.relu.default),
250            1,
251        )
252        self.assertEqual(
253            count_node(graph_after_passes, exir_ops.edge.aten.relu_.default),
254            0,
255        )
256
257    @parameterized.expand(
258        [
259            # split the only dimension
260            [(50,), i, 0]
261            for i in range(2, 7)
262        ]
263        + [
264            # split the leading dim
265            [(10, 2, 3), i, 0]
266            for i in range(2, 7)
267        ]
268        + [
269            # split the trailing dim
270            [(3, 3, 6), i, 2]
271            for i in range(2, 6)
272        ]
273        + [
274            # split the dim in the middle
275            [(3, 5, 14, 2, 3), i, 2]
276            for i in range(2, 7)
277        ]
278    )
279    @torch.no_grad()
280    def test_replace_functionally_equivalent_op_targets_unsafe_split(
281        self, shape: Tuple[int], split_size: int, dim: int
282    ):
283        class TensorSplitWithSizes(torch.nn.Module):
284            def __init__(self, split_size: int, dim: int, op: OpOverload):
285                super().__init__()
286                self.split_size = split_size
287                self.dim = dim
288                self.op = op
289
290            def forward(self, x: torch.Tensor):
291                return self.op(x, self.split_size, self.dim)
292
293        x = torch.randn(shape)
294        model = TensorSplitWithSizes(split_size, dim, torch.unsafe_split)
295        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
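        # unsafe_split is functionally equivalent to split_with_sizes_copy, so the
        # pass should simply retarget the op.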
296        p = ReplaceFunctionallyEquivalentOpTargets()
297
298        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
299        self.assertEqual(
300            count_node(
301                graph_after_passes, exir_ops.edge.aten.split_with_sizes_copy.default
302            ),
303            1,
304        )
305        self.assertEqual(
306            count_node(graph_after_passes, exir_ops.edge.aten.unsafe_split.Tensor),
307            0,
308        )
309
310    @parameterized.expand(
311        [
312            [(16, 32)],
313            [(1, 240)],
314            [(4, 16)],
315        ]
316    )
317    @torch.no_grad()
318    def test_replace_t_copy_with_transpose(self, shape: Tuple[int]):
319        class TCopy(torch.nn.Module):
320            def forward(self, x: torch.Tensor):
321                return exir_ops.edge.aten.t_copy(x)
322
323        w = torch.randn(shape)
324        inputs = (w,)
325        p1 = ReplaceTCopyWithTransposePass()
326        p2 = ReplacePermuteWithTransposePass()
327        model = TCopy()
328        graph_module = export_to_edge(model, inputs).exported_program().graph_module
329        graph_after_passes = cast(
330            PassResult, p2(cast(PassResult, p1(graph_module)).graph_module)
331        ).graph_module
332        self.assertEqual(
333            count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int),
334            1,
335        )
336        self.assertEqual(
337            count_node(graph_after_passes, exir_ops.edge.aten.t_copy),
338            0,
339        )
340
341    @parameterized.expand(
342        [
343            [(1, 8, 33), 8, 16, 3],
344            [(1, 8, 33), 8, 16, 5, 2],
345            [(1, 8, 33), 8, 16, 3, 2, 4, 3, False, False, False],
346            # channel last
347            [(1, 33, 8), 8, 16, 3, 1, 0, 1, False, False, True],
348            [(1, 33, 8), 8, 16, 5, 2, 0, 1, False, True, True],
349        ]
350    )
351    @torch.no_grad()
352    def test_replace_transposed_conv_with_linear(
353        self,
354        shape: Tuple[int],
355        in_channels: int,
356        out_channels: int,
357        kernel: int,
358        stride: int = 1,
359        padding: int = 0,
360        dilation: int = 1,
361        depthwise: bool = False,
362        bias: bool = True,
363        channel_last: bool = False,
364    ):
365        class TConv(torch.nn.Module):
366            def __init__(self):
367                super().__init__()
368                self.tconv1d = torch.nn.ConvTranspose1d(
369                    in_channels,
370                    out_channels,
371                    kernel,
372                    stride=stride,
373                    padding=padding,
374                    dilation=dilation,
375                    groups=in_channels if depthwise else 1,
376                    bias=bias,
377                )
378
379            def forward(self, x: torch.Tensor):
380                if channel_last:
381                    x = x.permute([0, 2, 1])
382                x = self.tconv1d(x)
383                if channel_last:
384                    x = x.permute([0, 2, 1])
385                return x
386
387        x = torch.randn(shape)
388        model = TConv()
389        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
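        # First canonicalize the aten convolution to the Jarvis/cadence convolution,
        # then replace the transposed convolution with a single linear op.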
390        p1 = ReplaceAtenConvolutionWithJarvisConvolutionPass()
391        p2 = ReplaceTransposedConvWithLinearPass()
392        graph_after_passes = cast(
393            PassResult, p2(cast(PassResult, p1(graph_module)).graph_module)
394        ).graph_module
395        self.assertEqual(
396            count_node(graph_after_passes, exir_ops.edge.aten.linear.default),
397            1,
398        )
399        self.assertEqual(
400            count_node(graph_after_passes, exir_ops.edge.aten.convolution.default),
401            0,
402        )
403        self.assertEqual(
404            count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default),
405            0,
406        )
407
408    @parameterized.expand(
409        [
410            [(1, 8, 33), 8, 16, 3, 2, 4, 3, False, False, False],
411            # depthwise
412            [(1, 8, 33), 8, 16, 3, 1, 0, 1, True, False, False],
413            [(1, 8, 33), 8, 16, 3, 2, 4, 3, True, False, False],
414            # channel last (uses a permute op before calling conv1d)
415            [(1, 33, 8), 8, 16, 3, 1, 0, 1, False, False, True],
416            [(1, 33, 8), 8, 16, 3, 2, 4, 3, True, False, True],
417        ]
418    )
419    @torch.no_grad()
420    def test_replace_convolution_optional_args_with_concrete_args(
421        self,
422        shape: Tuple[int],
423        in_channels: int,
424        out_channels: int,
425        kernel: int,
426        stride: int = 1,
427        padding: int = 0,
428        dilation: int = 1,
429        depthwise: bool = False,
430        bias: bool = True,
431        channel_last: bool = False,
432    ):
433        class Conv(torch.nn.Module):
434            def __init__(self):
435                super().__init__()
436                self.conv1d = torch.nn.Conv1d(
437                    in_channels,
438                    out_channels,
439                    kernel,
440                    stride=stride,
441                    padding=padding,
442                    dilation=dilation,
443                    groups=in_channels if depthwise else 1,
444                    bias=bias,
445                )
446
447            def forward(self, x: torch.Tensor):
448                if channel_last:
449                    x = x.permute([0, 2, 1])
450                x = self.conv1d(x)
451                if channel_last:
452                    x = x.permute([0, 2, 1])
453                return x
454
455        x = torch.randn(shape)
456        model = Conv()
457
458        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
459
460        p = ReplaceConvolutionOptionalArgsWithConcreteArgsPass()
461
462        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
463        self.assertEqual(
464            count_node(graph_after_passes, exir_ops.edge.aten.full.default),
465            1,
466        )
467        self.assertEqual(
468            count_node(graph_after_passes, exir_ops.edge.aten.convolution.default),
469            1,
470        )
471
472    @parameterized.expand(
473        [
474            [(1, 2, 3), (1, 1)],
475            [
476                (20, 1, 80),
477                (1, 4),
478            ],
479        ]
480    )
481    @torch.no_grad()
482    def test_replace_pad_with_cat(self, shape: Tuple[int], padding: Tuple[int]):
483        # F.pad is converted to aten::constant_pad_nd after functionalization & decomposition.
484        class Padding(torch.nn.Module):
485            def __init__(self):
486                super().__init__()
487                self.padding = padding
488
489            def forward(self, x: torch.Tensor):
490                return F.pad(x, self.padding)
491
492        model = Padding()
493        x = torch.randn(shape)
494        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
495
496        p = ReplacePadWithCatPass()
497
498        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
499        self.assertEqual(
500            count_node(graph_after_passes, exir_ops.edge.aten.cat.default),
501            1,
502        )
503
504        self.assertEqual(
505            count_node(graph_after_passes, exir_ops.edge.aten.pad.default),
506            0,
507        )
508
509    @torch.no_grad()
510    def test_replace_repeat_with_cat(self):
511        class Repeat(torch.nn.Module):
512            def forward(self, x):
513                x1 = torch.add(x, 2.4, 3.1)
514                return torch.ops.aten.repeat(x1, [1, 2])
515
516        x = torch.ones(3, 5)
517        graph_module = export_to_edge(Repeat(), (x,)).exported_program().graph_module
518
519        p = ReplaceRepeatWithCatPass()
520        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
521        self.assertEqual(
522            count_node(graph_after_passes, exir_ops.edge.aten.cat.default),
523            1,
524        )
525
526        self.assertEqual(
527            count_node(graph_after_passes, exir_ops.edge.aten.repeat.default),
528            0,
529        )
530
531    @parameterized.expand(
532        [
533            # x, mask
534            [(1,)],
535            [(3, 4)],
536            [(7, 8, 3)],
537            [(3, 3, 2, 4)],
538            [(36, 1, 2, 80), (1)],
539            # tests where mask will be broadcasted
540            [(36, 1, 2, 80), (1, 1, 2, 1)],
541            [(36, 2, 8, 4), (36, 1, 1, 4)],
542            [(36, 2, 8, 4), (2, 1, 4)],
543        ]
544    )
545    @torch.no_grad()
546    def test_replace_masked_scalar_tensor_with_full(
547        self,
548        shape: Tuple[int],
549        mask_shape: Union[Tuple[int, ...], None] = None,
550    ):
551        class MaskedFill(torch.nn.Module):
552            def __init__(self, value: float):
553                super().__init__()
554                self.value = value
555
556            def forward(self, x: torch.Tensor, mask: torch.Tensor):
557                return torch.masked_fill(x, mask, self.value)
558
559        x = torch.randn(shape)
560        mask = torch.randn(mask_shape if mask_shape else shape) > 0
561        value = 0.5 * torch.mean(x).item()
562        model = MaskedFill(value)
563        graph_module = export_to_edge(model, (x, mask)).exported_program().graph_module
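        # masked_fill decomposes into a scalar_tensor + where pair during export;
        # the pass should turn the scalar_tensor into a full op, leaving one full
        # node and one where node.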
564
565        p = ReplaceScalarTensorWithFullPass()
566
567        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
568        self.assertEqual(
569            count_node(graph_after_passes, exir_ops.edge.aten.full.default),
570            1,
571        )
572
573        self.assertEqual(
574            count_node(graph_after_passes, exir_ops.edge.aten.where.self),
575            1,
576        )
577
578        self.assertEqual(
579            count_node(graph_after_passes, exir_ops.edge.aten.masked_fill),
580            0,
581        )
582
583    @parameterized.expand(
584        [
585            [(1), 1.5],
586            [(1), 0.0],
587        ]
588    )
589    @torch.no_grad()
590    def test_replace_scalar_tensor_with_full(self, shape: Tuple[int], value: float):
591        class ScalarTensor(torch.nn.Module):
592            def __init__(self, shape: Tuple[int], value: float):
593                super().__init__()
594                self.shape = shape
595                self.value = value
596
597            def forward(self, x: torch.Tensor):
598                return torch.scalar_tensor(value)
599
600        model = ScalarTensor(shape, value)
601        x = torch.randn(shape)
602        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
603
604        p = ReplaceScalarTensorWithFullPass()
605
606        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
607        self.assertEqual(
608            count_node(graph_after_passes, exir_ops.edge.aten.full.default),
609            1,
610        )
611
612        self.assertEqual(
613            count_node(graph_after_passes, exir_ops.edge.aten.scalar_tensor.default),
614            0,
615        )
616
617    @torch.no_grad()
618    def test_replace_linear_with_fully_connected(self):
619        shape, in_features, out_features, bias = (1, 14), 14, 128, False
620        model = torch.nn.Linear(in_features, out_features, bias=bias)
621        x = torch.randn(shape)
622
623        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
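        # Lower the linear layer through a chain of passes:
        # permute -> transpose, mm -> addmm, addmm -> linear, linear -> fully_connected.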
624        permute_to_trans_pass = ReplacePermuteWithTransposePass()
625        mm_to_addmm_pass = ReplaceMMWithAddMMPass()
626        add_to_linear_pass = ReplaceAddMMWithLinearPass()
627        linear_to_fullyconnected_pass = ReplaceLinearWithFullyConnectedOpPass()
628        graph_after_passes = linear_to_fullyconnected_pass(
629            add_to_linear_pass(
630                mm_to_addmm_pass(
631                    permute_to_trans_pass(graph_module).graph_module
632                ).graph_module
633            ).graph_module
634        ).graph_module
635        self.assertIsNotNone(graph_after_passes)
636
637        self.assertEqual(
638            count_node(graph_after_passes, exir_ops.edge.aten.full.default),
639            1,
640        )
641
642        self.assertEqual(
643            count_node(
644                graph_after_passes, exir_ops.edge.cadence.fully_connected.default
645            ),
646            1,
647        )
648
649        self.assertEqual(
650            count_node(graph_after_passes, exir_ops.edge.aten.linear),
651            0,
652        )
653
654    @parameterized.expand(
655        [
656            [(4, 16, 256), 256, 512, True],
657            [(7, 17, 12), 12, 34, False],
658        ]
659    )
660    @torch.no_grad()
661    def test_replace_addmm_with_linear(
662        self, shape: Tuple[int], in_features: int, out_features: int, bias: bool
663    ):
664        class AddMM(torch.nn.Module):
665            def __init__(self, alpha: float = 1, beta: float = 1):
666                super().__init__()
667                self.alpha = alpha
668                self.beta = beta
669
670            def forward(self, x: torch.Tensor, y: torch.Tensor, z: torch.Tensor):
671                return torch.addmm(
672                    x, y, z.transpose(1, 0), alpha=self.alpha, beta=self.beta
673                )
674
675        # alpha and beta must both be 1 to enable ReplaceAddMMWithLinearPass
676        # get_attr will always turn into placeholders and mutable outputs in PT2
677        M, K, N, alpha, beta = 14, 48, 24, 1.0, 1.0
678        x = torch.randn(N)
679        y = torch.randn(M, K)
680        z = torch.randn(N, K)
681
682        # test addmm
683        model = AddMM(alpha=alpha, beta=beta)
684        graph_module = export_to_edge(model, (x, y, z)).exported_program().graph_module
685
686        tp = ReplacePermuteWithTransposePass()
687        ap = ReplaceAddMMWithLinearPass()
688        graph_after_passes = cast(
689            PassResult, ap(cast(PassResult, tp(graph_module)).graph_module)
690        ).graph_module
691        self.assertIsNotNone(graph_after_passes)
692
693        self.assertEqual(
694            count_node(graph_module, exir_ops.edge.aten.addmm.default),
695            1,
696        )
697
698        # Assert that all the aten.addmm nodes are removed.
699        self.assertEqual(
700            count_node(graph_after_passes, exir_ops.edge.aten.linear.default),
701            1,
702        )
703        self.assertEqual(
704            count_node(graph_after_passes, exir_ops.edge.aten.addmm.default),
705            0,
706        )
707
708    @torch.no_grad()
709    def test_replace_mm_with_addmm(self):
710        # The mm ops will be converted to addmm ops by Jarvis
711        class MM(torch.nn.Module):
712            def __init__(self, K, N):
713                super().__init__()
714                self.K = K
715                self.N = N
716
717            def forward(self, y: torch.Tensor, z: torch.Tensor):
718                return torch.ops.aten.mm(y, z)
719
720        M, K, N = 14, 48, 24
721        y = torch.randn(M, K)
722        z = torch.randn(K, N)
723
724        # test addmm
725        model = MM(K, N)
726        graph_module = export_to_edge(model, (y, z)).exported_program().graph_module
727
728        # First, replace the aten.mm with an aten.addmm op
729        p = ReplaceMMWithAddMMPass()
730        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
731        self.assertIsNotNone(graph_after_passes)
732
733        # Assert that all the aten.mm nodes are removed.
734        self.assertEqual(
735            count_node(graph_after_passes, exir_ops.edge.aten.addmm.default),
736            1,
737        )
738
739        self.assertEqual(
740            count_node(graph_after_passes, exir_ops.edge.aten.mm),
741            0,
742        )
743
744    @parameterized.expand(
745        [
746            # shape
747            [(5, 1, 6, 7)],
748            [(1)],
749            [(4, 3, 2)],
750            # shape, dim to squeeze
751            [(2, 1), 0],
752            [(2, 7, 1, 3), 1],
753            [(2, 1, 3), 2],
754        ]
755    )
756    @torch.no_grad()
757    def test_replace_squeeze_with_view(self, shape: Tuple[int], dim=None):
758        # The squeeze ops will be converted to view ops by Jarvis
759        class Squeeze(torch.nn.Module):
760            def __init__(self, dim):
761                super().__init__()
762                self.dim = dim
763
764            def forward(self, x: torch.Tensor):
765                if self.dim is None:
766                    return torch.squeeze(x)
767                return torch.squeeze(x, self.dim)
768
769        model = Squeeze(dim)
770        x = torch.randn(shape)
771        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
772
773        # First, replace the aten.squeeze_copy with an aten.view_copy op
774        p = ReplaceSqueezeAndUnsqueezeWithViewPass()
775
776        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
777        self.assertIsNotNone(graph_after_passes)
778
779        # Assert that all the aten.squeeze_copy nodes are removed.
780        self.assertEqual(
781            count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default),
782            1,
783        )
784        self.assertEqual(
785            count_node(graph_after_passes, exir_ops.aten.squeeze_copy),
786            0,
787        )
788
789    @parameterized.expand(
790        [
791            # shape, dim to unsqueeze
792            [(5, 6, 7), 0],
793            [(5, 6, 7), -1],
794            [(5, 6, 7), 3],
795            [(5, 6, 7), 2],
796        ]
797    )
798    @torch.no_grad()
799    def test_replace_unsqueeze_with_view(self, shape: Tuple[int], dim: int):
800        class Unsqueeze(torch.nn.Module):
801            def __init__(self, dim):
802                super().__init__()
803                self.dim = dim
804
805            def forward(self, x: torch.Tensor):
806                return torch.unsqueeze(x, self.dim)
807
808        # Test that the pass works for all dims.
809        model = Unsqueeze(dim)
810        x = torch.randn(5, 6, 7)
811        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
812
813        # First, replace the aten.unsqueeze_copy with an aten.view_copy op
814        p = ReplaceSqueezeAndUnsqueezeWithViewPass()
815
816        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
817        self.assertIsNotNone(graph_after_passes)
818
819        # Assert that all the aten.unsqueeze_copy nodes are removed.
820        self.assertEqual(
821            count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default),
822            1,
823        )
824        self.assertEqual(
825            count_node(graph_after_passes, exir_ops.aten.unsqueeze_copy),
826            0,
827        )
828
829    @torch.no_grad()
830    def test_replace_single_element_tensor_arguments_from_full_op_with_scalar(
831        self,
832        in_features: int = 16,
833        out_features: int = 16,
834    ):
835        # Tensors - these will be inputs to graph.
836        x = torch.randn([1, in_features])
837
838        inputs = (x,)
839        model = torch.nn.Linear(in_features=in_features, out_features=out_features)
840        quantized_model = quantize_pt2(model, inputs)
841
842        exported_program = export_to_edge(quantized_model, inputs).exported_program()
843
844        # By default, the quantized linear op should have constant scalar attributes.
845        self.assertTargetCountsEqual(
846            exported_program.graph_module,
847            [
848                # One quantized linear op.
849                (exir_ops.edge.cadence.quantized_linear.default, 1),
850                # No per tensor quantized linear ops.
851                (exir_ops.edge.cadence.quantized_linear.per_tensor, 0),
852                # Three aten.full ops.
853                (exir_ops.edge.aten.full.default, 3),
854            ],
855        )
856
857        # Apply replacement pass.
858        p = ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass()
859        graph_after_passes = p(exported_program.graph_module)
860        self.assertIsNotNone(graph_after_passes)
861        gm = dead_code_elimination_pass(graph_after_passes.graph_module).graph_module
862
863        # By default, the quantized linear op should have constant scalar attributes.
864        self.assertTargetCountsEqual(
865            gm,
866            [
867                # No default quantized linear op.
868                (exir_ops.edge.cadence.quantized_linear.default, 0),
869                # The default quantized linear op will be replaced with quantized_linear.per_tensor.
870                (exir_ops.edge.cadence.quantized_linear.per_tensor, 1),
871                # No aten.full ops.
872                (exir_ops.edge.aten.full.default, 0),
873            ],
874        )
875
876    @torch.no_grad()
877    def test_replace_single_element_tensor_arguments_from_full_op_with_scalar_tuple_args(
878        self,
879        in_features: int = 16,
880        out_features: int = 16,
881    ):
882        # Tensors - these will be inputs to graph.
883        x = torch.randn([1, in_features])
884
885        inputs = (x,)
886        model = torch.nn.Linear(in_features=in_features, out_features=out_features)
887        quantized_model = quantize_pt2(model, inputs)
888
889        exported_program = export_to_edge(quantized_model, inputs).exported_program()
890
891        # By default, the quantized linear op should have constant scalar attributes.
892        self.assertTargetCountsEqual(
893            exported_program.graph_module,
894            [
895                # One quantized linear op.
896                (exir_ops.edge.cadence.quantized_linear.default, 1),
897                # No per tensor quantized linear ops.
898                (exir_ops.edge.cadence.quantized_linear.per_tensor, 0),
899                # Three aten.full ops.
900                (exir_ops.edge.aten.full.default, 3),
901            ],
902        )
903
904        for node in exported_program.graph_module.graph.nodes:
905            # Replace the `shape` argument for aten.full op with a tuple.
906            if node.target == exir_ops.edge.aten.full.default:
907                node.args = (tuple(node.args[0]), node.args[1])
908
909        # Apply replacement pass.
910        p = ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass()
911        graph_after_passes = p(exported_program.graph_module)
912        self.assertIsNotNone(graph_after_passes)
913        gm = dead_code_elimination_pass(graph_after_passes.graph_module).graph_module
914
915        # By default, the quantized linear op should have constant scalar attributes.
916        self.assertTargetCountsEqual(
917            gm,
918            [
919                # No default quantized linear op.
920                (exir_ops.edge.cadence.quantized_linear.default, 0),
921                # The default quantized linear op will be replaced with quantized_linear.per_tensor.
922                (exir_ops.edge.cadence.quantized_linear.per_tensor, 1),
923                # No aten.full ops.
924                (exir_ops.edge.aten.full.default, 0),
925            ],
926        )
927
928    @torch.no_grad()
929    def test_replace_conv1d_with_linear(self):
930        class Conv(torch.nn.Module):
931            def __init__(self, in_features: int, out_features: int, kernel_size: int):
932                super().__init__()
933                self.conv1d = torch.nn.Conv1d(in_features, out_features, kernel_size)
934
935            def forward(self, x):
936                return self.conv1d(x)
937
938        model_conv1d = Conv(96, 192, 7)
939        x = torch.randn(1, 96, 7)
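        # The kernel size (7) equals the input length, so the convolution produces a
        # single output position and is equivalent to a linear layer.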
940        graph_module = (
941            export_to_edge(model_conv1d, (x,)).exported_program().graph_module
942        )
943
944        # First, replace the aten convolution with a cadence.convolution op
945        p1 = ReplaceAtenConvolutionWithJarvisConvolutionPass()
946        temp_graph = p1(graph_module).graph_module
947        self.assertIsNotNone(temp_graph)
948
949        p2 = ReplaceTrivialConvWithLinear()
950        graph_after_passes = p2(temp_graph).graph_module
951
952        # Assert that conv1d is trivially converted to linear
953        self.assertEqual(
954            count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default), 0
955        )
956        self.assertEqual(
957            count_node(graph_after_passes, exir_ops.edge.cadence.im2row.default), 0
958        )
959        self.assertEqual(
960            count_node(graph_after_passes, exir_ops.edge.aten.linear.default)
961            + count_node(
962                graph_after_passes, exir_ops.edge.cadence.fully_connected.default
963            ),
964            1,
965        )
966
967    @torch.no_grad()
968    def test_replace_conv2d_with_linear(self):
969        class Conv(torch.nn.Module):
970            def __init__(self, in_features: int, out_features: int, kernel_size: int):
971                super().__init__()
972                self.conv2d = torch.nn.Conv2d(in_features, out_features, kernel_size)
973
974            def forward(self, x):
975                return self.conv2d(x)
976
977        model_conv2d = Conv(96, 192, 7)
978        x = torch.randn(1, 96, 7, 7)
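        # The 7x7 kernel matches the 7x7 spatial input, so this convolution is also
        # trivially equivalent to a linear layer.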
979        graph_module = (
980            export_to_edge(model_conv2d, (x,)).exported_program().graph_module
981        )
982
983        # First, replace the aten convolution with a cadence.convolution op
984        p1 = ReplaceAtenConvolutionWithJarvisConvolutionPass()
985        temp_graph = p1(graph_module).graph_module
986        self.assertIsNotNone(temp_graph)
987
988        p2 = ReplaceTrivialConvWithLinear()
989        graph_after_passes = p2(temp_graph).graph_module
990
991        # Assert that conv2d is trivially converted to linear
992        self.assertEqual(
993            count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default), 0
994        )
995        self.assertEqual(
996            count_node(graph_after_passes, exir_ops.edge.cadence.im2row.default), 0
997        )
998        self.assertEqual(
999            count_node(graph_after_passes, exir_ops.edge.aten.linear.default)
1000            + count_node(
1001                graph_after_passes, exir_ops.edge.cadence.fully_connected.default
1002            ),
1003            1,
1004        )
1005
1006    @torch.no_grad()
1007    def test_replace_conv2d_with_im2row_and_linear(self):
1008        class Conv(torch.nn.Module):
1009            def __init__(self, in_features: int, out_features: int, kernel_size: int):
1010                super().__init__()
1011                self.conv2d = torch.nn.Conv2d(in_features, out_features, kernel_size)
1012
1013            def forward(self, x):
1014                return self.conv2d(x)
1015
1016        model_conv2d = Conv(96, 192, 7)
1017        x = torch.randn(1, 96, 47, 37)
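        # The 47x37 input is larger than the 7x7 kernel, so this is a non-trivial
        # convolution that must be lowered to im2row followed by a linear op.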
1018        graph_module = (
1019            compiler.export_to_cadence(model_conv2d, (x,))
1020            .exported_program()
1021            .graph_module
1022        )
1023
1024        p = ReplaceConvWithIm2RowAndLinear()
1025        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
1026
1027        # Assert that the convolution is converted to im2row + linear
1028        self.assertEqual(
1029            count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default), 0
1030        )
1031        self.assertEqual(
1032            count_node(graph_after_passes, exir_ops.edge.cadence.im2row.default), 1
1033        )
1034        self.assertEqual(
1035            count_node(graph_after_passes, exir_ops.edge.aten.linear.default), 1
1036        )
1037
1038    @parameterized.expand(
1039        [
1040            [(3, 1, 5), 1, 0],
1041            [(3, 4, 1), 2, -1],
1042        ]
1043    )
1044    @torch.no_grad()
1045    def test_replace_select_with_view(self, shape: Tuple[int], dim: int, index: int):
1046        class Select(torch.nn.Module):
1047            def forward(self, x):
1048                return x.select(dim, index)
1049
1050        x = torch.randn(shape)
1051        graph_module = export_to_edge(Select(), (x,)).exported_program().graph_module
1052
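        # The selected dimension has size 1 in both cases, so the select simply drops
        # that dimension and can be expressed as a view.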
1053        p = ReplaceSelectWithViewOpPass()
1054
1055        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
1056
1057        # Assert that select op was replaced with view op
1058        self.assertEqual(
1059            count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0
1060        )
1061        self.assertEqual(
1062            count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 1
1063        )
1064
1065    @parameterized.expand(
1066        [
1067            [(2, 1, 3, 1), 1, 3, torch.float32],
1068            [(2, 1, 5), 1, 0, torch.int64],
1069            [(3, 1, 5), 0, 1, torch.int64],
1070        ]
1071    )
1072    @torch.no_grad()
1073    def test_replace_nop_transpose_with_view(
1074        self,
1075        shape: Tuple[int],
1076        dim0: int,
1077        dim1: int,
1078        dtype: torch.dtype = torch.float32,
1079    ):
1080        class Transpose(torch.nn.Module):
1081            def forward(self, x):
1082                return x.transpose(dim0, dim1)
1083
1084        _max_value = 127
1085        x = (torch.rand(shape) * _max_value).to(dtype=dtype)
1086        graph_module = export_to_edge(Transpose(), (x,)).exported_program().graph_module
1087
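        # Every parameterized case swaps a size-1 dimension, so the transpose leaves
        # the underlying data order unchanged and can be replaced by a view.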
1088        p = ReplaceNopTransposeOrPermuteWithViewPass()
1089
1090        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
1091
1092        # Assert that transpose op was removed, and a view op was placed instead
1093        self.assertEqual(
1094            count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int), 0
1095        )
1096        self.assertEqual(
1097            count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 1
1098        )
1099
1100    @parameterized.expand(
1101        [
1102            # permutations that can be replaced by view
1103            [(3, 1, 3, 1, 4), (0, 2, 4, 1, 3), torch.float32],
1104            [(1, 3, 4), (1, 2, 0), torch.float32],
1105        ]
1106    )
1107    @torch.no_grad()
1108    def test_replace_nop_permute_with_view(self, input_shape, dims, dtype):
1109        class Permute(torch.nn.Module):
1110            def forward(self, x):
1111                return torch.permute(x, dims)
1112
1113        x = torch.randn(input_shape).to(dtype=dtype)
1114        graph_module = export_to_edge(Permute(), (x,)).exported_program().graph_module
1115
1116        p = ReplaceNopTransposeOrPermuteWithViewPass()
1117        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
1118
1119        # Assert that permute op was removed, and a view op was placed instead
1120        self.assertEqual(
1121            count_node(graph_after_passes, exir_ops.edge.aten.permute_copy.default), 0
1122        )
1123        self.assertEqual(
1124            count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 1
1125        )
1126
1127    @parameterized.expand(
1128        [
1129            # permutations replaced by transpose
1130            [(3, 4), [1, 0], torch.float32],
1131            [(3, 4, 6), (0, 2, 1), torch.float32],
1132        ]
1133    )
1134    @torch.no_grad()
1135    def test_replace_permute_with_transpose(self, input_shape, dims, dtype):
1136        class Permute(torch.nn.Module):
1137            def forward(self, x):
1138                return torch.permute(x, dims)
1139
1140        x = torch.randn(input_shape).to(dtype=dtype)
1141        graph_module = export_to_edge(Permute(), (x,)).exported_program().graph_module
1142
1143        p = ReplacePermuteWithTransposePass()
1144        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
1145
1146        # Assert that permute op was replaced by a transpose op
1147        self.assertEqual(
1148            count_node(graph_after_passes, exir_ops.edge.aten.permute_copy.default), 0
1149        )
1150        self.assertEqual(
1151            count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int), 1
1152        )
1153
1154    @parameterized.expand(
1155        [
1156            # identity permutation: should not be replaced by a transpose
1157            [(3, 4), [0, 1], torch.float32],
1158        ]
1159    )
1160    @torch.no_grad()
1161    def test_replace_permute_with_transpose_nop(self, input_shape, dims, dtype):
1162        class Permute(torch.nn.Module):
1163            def forward(self, x):
1164                return torch.permute(x, dims)
1165
1166        x = torch.randn(input_shape).to(dtype=dtype)
1167        graph_module = export_to_edge(Permute(), (x,)).exported_program().graph_module
1168
1169        p = ReplacePermuteWithTransposePass()
1170        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
1171
1172        # Assert that the nop permute was removed without inserting a transpose op
1173        self.assertEqual(
1174            count_node(graph_after_passes, exir_ops.edge.aten.permute_copy.default), 0
1175        )
1176        self.assertEqual(
1177            count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int), 0
1178        )
1179
1180    def test_replace_aten_linalg_vector_norm_with_cadence_linalg_vector_norm(self):
1181        class LinalgVectorNorm(torch.nn.Module):
1182            def forward(self, x: torch.Tensor):
1183                return torch.linalg.vector_norm(x)
1184
1185        x = torch.randn(32)
1186
1187        graph_module = (
1188            export_to_edge(LinalgVectorNorm(), (x,)).exported_program().graph_module
1189        )
1190
1191        p = ReplaceAtenLinalgVectorNormWithCadenceLinalgVectorNormPass()
1192        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
1193
1194        # Assert that aten.linalg_vector_norm op was replaced by a
1195        # cadence.linalg_vector_norm op
1196        self.assertEqual(
1197            count_node(
1198                graph_after_passes,
1199                exir_ops.edge.aten.linalg_vector_norm.default,
1200            ),
1201            0,
1202        )
1203        self.assertEqual(
1204            count_node(
1205                graph_after_passes, exir_ops.edge.cadence.linalg_vector_norm.default
1206            ),
1207            1,
1208        )
1209
1210
1211class TestReplaceIm2rowWithViewPass(unittest.TestCase):
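    """Tests that im2row is replaced by a view only when it is a pure reshape
    (kernel covers the whole input, no dilation), and is kept otherwise."""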
1212    def test_no_replacement_for_conv(self):
1213        # Create a graph with a single im2row node.
1214        x = torch.randn(1, 3, 224, 224)
1215        pad_value = torch.randn(1)
1216        channels_last = False
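        # The (2, 2) kernel is smaller than the input, so im2row performs real patch
        # extraction and cannot be replaced by a view.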
1217        gm = single_op_builder(
1218            placeholders=(x, pad_value),
1219            op=exir_ops.edge.cadence.im2row.default,
1220            args=(x, (2, 2), (1, 1), (0, 0), (1, 1), pad_value, channels_last),
1221        )
1222        # Check if graph module is valid by running exportpass on it.
1223        gm = ExportPass().call(gm).graph_module
1224        self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1)
1225        self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 0)
1226
1227        # Apply replacement pass.
1228        p = ReplaceIm2RowWithViewPass()
1229        gm_after_replacement = p.call(gm).graph_module
1230        # Check that no replacement was made.
1231        self.assertEqual(
1232            count_node(gm_after_replacement, exir_ops.edge.cadence.im2row.default), 1
1233        )
1234        self.assertEqual(
1235            count_node(gm_after_replacement, exir_ops.edge.aten.view_copy.default), 0
1236        )
1237
1238    def test_no_replace_for_dilation(self):
1239        # Create a graph with a single im2row node.
1240        x = torch.randn(1, 3, 5, 7)
1241        pad_value = torch.randn(1)
1242        channels_last = False
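        # im2row with a non-unit dilation gathers non-contiguous elements, so it
        # cannot be replaced by a view.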
1243        gm = single_op_builder(
1244            placeholders=(x, pad_value),
1245            op=exir_ops.edge.cadence.im2row.default,
1246            args=(x, (3, 4), (2, 2), (0, 0), (1, 1), pad_value, channels_last),
1247        )
1248        # Check if graph module is valid by running exportpass on it.
1249        gm = ExportPass().call(gm).graph_module
1250        self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1)
1251        self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 0)
1252
1253        # Apply replacement pass.
1254        p = ReplaceIm2RowWithViewPass()
1255        gm_after_replacement = p.call(gm).graph_module
1256        self.assertEqual(
1257            count_node(gm_after_replacement, exir_ops.edge.cadence.im2row.default), 1
1258        )
1259        self.assertEqual(
1260            count_node(gm_after_replacement, exir_ops.edge.aten.view_copy.default), 0
1261        )
1262
1263    def test_replace_linear_like_conv(self):
1264        # Create a graph with a single im2row node.
1265        in_h, in_w = 13, 15
1266        x = torch.randn(1, 3, in_h, in_w)
1267        pad_value = torch.randn(1)
1268        channels_last = False
1269        gm = single_op_builder(
1270            placeholders=(x, pad_value),
1271            op=exir_ops.edge.cadence.im2row.default,
1272            args=(x, (in_h, in_w), (1, 1), (0, 0), (1, 1), pad_value, channels_last),
1273        )
1274        # Check if graph module is valid by running exportpass on it.
1275        gm = ExportPass().call(gm).graph_module
1276        self.assertEqual(count_node(gm, exir_ops.edge.cadence.im2row.default), 1)
1277        self.assertEqual(count_node(gm, exir_ops.edge.aten.view_copy.default), 0)
1278
1279        # Apply replacement pass.
1280        p = ReplaceIm2RowWithViewPass()
1281        gm_after_replacement = p.call(gm).graph_module
1282        # In this test, the kernel width/height is the same as the input width/height.
1283        self.assertEqual(
1284            count_node(gm_after_replacement, exir_ops.edge.cadence.im2row.default), 0
1285        )
1286        self.assertEqual(
1287            count_node(gm_after_replacement, exir_ops.edge.aten.view_copy.default), 1
1288        )
1289
1290
1291class TestForceChannelLastForConvPass(unittest.TestCase):
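    """Tests that convolution ops are forced to channel-last layout, inserting
    transposes/permutes around them only when the input is not already channel-last."""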
1292    def create_conv1d_graphmodule(
1293        self, channels_last: Optional[bool] = None
1294    ) -> torch.fx.GraphModule:
1295        """Helper to create a convolution node.
1296
1297        convolution(
1298            Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding,
1299            int[] dilation, int groups, bool channel_last=False) -> (Tensor Y)
1300        """
1301        if channels_last:
1302            x = torch.randn(1, 224, 3)
1303            w = torch.randn(16, 16, 3)
1304        else:
1305            x = torch.randn(1, 3, 224)
1306            w = torch.randn(16, 3, 16)
1307        b = torch.randn(16)
1308        args = (x, w, b, (2, 2), (1, 1), (0, 0), 1)
1309        if channels_last is not None:
1310            args = args + (channels_last,)
1311        return single_op_builder(
1312            placeholders=(x, w, b),
1313            op=exir_ops.edge.cadence.convolution.default,
1314            args=args,
1315        )
1316
1317    def test_conv1d_default_channel_last(self):
1318        # Create a graph with a single convolution node.
1319        # Check if graph module is valid by running exportpass on it.
1320        gm = self.create_conv1d_graphmodule()
1321        gm = ExportPass().call(gm).graph_module
1322        self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1)
1323        self.assertEqual(count_node(gm, exir_ops.edge.aten.transpose_copy.int), 0)
1324
1325        # Apply replacement pass.
1326        p = ForceChannelLastForConvPass()
1327        gm_after_replacement = p.call(gm).graph_module
1328        # The convolution op is kept; transposes are inserted to make it channel-last.
1329        self.assertEqual(
1330            count_node(gm_after_replacement, exir_ops.edge.cadence.convolution.default),
1331            1,
1332        )
1333        self.assertEqual(
1334            count_node(gm_after_replacement, exir_ops.edge.aten.transpose_copy.int),
1335            # Three transposes are added, two for the input/weight and one for the output.
1336            3,
1337        )
1338        for node in gm_after_replacement.graph.nodes:
1339            if node.target != exir_ops.edge.cadence.convolution.default:
1340                continue
1341            # Check that the channel_last argument is set to True.
1342            self.assertEqual(len(node.args), 8, f"{node=}")
1343            self.assertTrue(node.args[7])
1344
1345    def test_conv1d_no_transpose_if_already_channel_last(self):
1346        gm = self.create_conv1d_graphmodule(channels_last=True)
1347        gm = ExportPass().call(gm).graph_module
1348        self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1)
1349
1350        # Apply replacement pass.
1351        p = ForceChannelLastForConvPass()
1352        gm_after_replacement = p.call(gm).graph_module
1353        # Check that no replacement was made.
1354        self.assertEqual(
1355            count_node(gm_after_replacement, exir_ops.edge.cadence.convolution.default),
1356            1,
1357        )
1358        self.assertEqual(
1359            count_node(gm_after_replacement, exir_ops.edge.aten.transpose_copy.int),
1360            0,
1361        )
1362        for node in gm_after_replacement.graph.nodes:
1363            if node.target != exir_ops.edge.cadence.convolution.default:
1364                continue
1365            # Check that the channel_last argument is set to True.
1366            self.assertEqual(len(node.args), 8, f"{node=}")
1367            self.assertTrue(node.args[7])
1368
1369    def create_convolution_graph_module(
1370        self, channels_last: Optional[bool] = None
1371    ) -> torch.fx.GraphModule:
1372        """Helper to create a convolution node.
1373
1374        convolution(
1375            Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding,
1376            int[] dilation, int groups, bool channel_last=False) -> (Tensor Y)
1377        """
1378        if channels_last:
1379            x = torch.randn(1, 224, 224, 3)
1380            w = torch.randn(16, 16, 16, 3)
1381        else:
1382            x = torch.randn(1, 3, 224, 224)
1383            w = torch.randn(16, 3, 16, 16)
1384        b = torch.randn(16)
1385        args = (x, w, b, (2, 2), (1, 1), (0, 0), 1)
1386        if channels_last is not None:
1387            args = args + (channels_last,)
1388        return single_op_builder(
1389            placeholders=(x, w, b),
1390            op=exir_ops.edge.cadence.convolution.default,
1391            args=args,
1392        )
1393
1394    def test_convolution_default_channel_last(self):
1395        # Create a graph with a single convolution node.
1396        # Check if graph module is valid by running exportpass on it.
1397        gm = self.create_convolution_graph_module()
1398        gm = ExportPass().call(gm).graph_module
1399        self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1)
1400        self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0)
1401
1402        # Apply replacement pass.
1403        p = ForceChannelLastForConvPass()
1404        gm_after_replacement = p.call(gm).graph_module
1405        # The convolution op is kept; permutes are inserted to make it channel-last.
1406        self.assertEqual(
1407            count_node(gm_after_replacement, exir_ops.edge.cadence.convolution.default),
1408            1,
1409        )
1410        self.assertEqual(
1411            count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default),
1412            # Three permutes are added, two for the input/weights and one for the output.
1413            3,
1414        )
1415        for node in gm_after_replacement.graph.nodes:
1416            if node.target != exir_ops.edge.cadence.convolution.default:
1417                continue
1418            # Check that the channel_last argument is set to True.
1419            self.assertEqual(len(node.args), 8, f"{node=}")
1420            self.assertTrue(node.args[7])
1421
1422    def test_no_transpose_if_already_channel_last(self):
1423        gm = self.create_convolution_graph_module(channels_last=True)
1424        gm = ExportPass().call(gm).graph_module
1425        self.assertEqual(count_node(gm, exir_ops.edge.cadence.convolution.default), 1)
1426
1427        # Apply replacement pass.
1428        p = ForceChannelLastForConvPass()
1429        gm_after_replacement = p.call(gm).graph_module
1430        # Check that no replacement was made.
1431        self.assertEqual(
1432            count_node(gm_after_replacement, exir_ops.edge.cadence.convolution.default),
1433            1,
1434        )
1435        self.assertEqual(
1436            count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default),
1437            0,
1438        )
1439        for node in gm_after_replacement.graph.nodes:
1440            if node.target != exir_ops.edge.cadence.convolution.default:
1441                continue
1442            # Check that the channel_last argument is set to True.
1443            self.assertEqual(len(node.args), 8, f"{node=}")
1444            self.assertTrue(node.args[7])
1445
1446    def create_quantized_convolution_graph_module(
1447        self, channels_last: Optional[bool] = None
1448    ) -> torch.fx.GraphModule:
1449        """Helper to create a quantized conv node.
1450
1451        quantized_conv(
1452            Tensor input, Tensor weight, Tensor bias, int[] stride, SymInt[] padding,
1453            int[] dilation, int groups, int input_zero_point, Tensor weight_zero_point,
1454            Tensor bias_scale, float out_scale, int out_zero_point, Tensor out_multiplier,
1455            Tensor out_shift, bool channel_last=False) -> (Tensor Z)
1456        """
1457        if channels_last:
1458            x = torch.randn(1, 224, 56, 3)
1459            w = torch.randn(16, 16, 16, 3)
1460        else:
1461            x = torch.randn(1, 3, 224, 56)
1462            w = torch.randn(16, 3, 16, 16)
1463        b = torch.randn(16)
1464        stride = (2, 2)
1465        padding = (0, 0)
1466        dilation = (1, 1)
1467        groups = 1
1468        input_zero_point = 0
1469        w_zero_point = torch.randn(1)
1470        b_scale = torch.randn(1)
1471        out_scale = 1
1472        out_zero_point = 0
1473        out_multiplier = torch.randn(1)
1474        out_shift = torch.randn(1)
1475        args = (
1476            x,
1477            w,
1478            b,
1479            stride,
1480            padding,
1481            dilation,
1482            groups,
1483            input_zero_point,
1484            w_zero_point,
1485            b_scale,
1486            out_scale,
1487            out_zero_point,
1488            out_multiplier,
1489            out_shift,
1490        )
1491        if channels_last is not None:
1492            args = args + (channels_last,)
1493        return single_op_builder(
1494            placeholders=(x, w, b, w_zero_point, b_scale, out_multiplier, out_shift),
1495            op=exir_ops.edge.cadence.quantized_conv.default,
1496            args=args,
1497        )
1498
1499    def test_quantized_convolution_default_channel_last(self):
1500        # Create a graph with a single quantized convolution node.
1501        gm = self.create_quantized_convolution_graph_module()
1502        self.assertEqual(
1503            count_node(gm, exir_ops.edge.cadence.quantized_conv.default), 1
1504        )
1505        self.assertEqual(count_node(gm, exir_ops.edge.aten.permute_copy.default), 0)
1506
1507        # Apply replacement pass.
1508        p = ForceChannelLastForConvPass()
1509        gm_after_replacement = p.call(gm).graph_module
1510        # Check that the quantized_conv node is still present after the pass.
1511        self.assertEqual(
1512            count_node(
1513                gm_after_replacement, exir_ops.edge.cadence.quantized_conv.default
1514            ),
1515            1,
1516        )
1517        # Three permutes are added, two for the input/weights and one for the output.
1518        self.assertEqual(
1519            count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default),
1520            3,
1521        )
1522        for node in gm_after_replacement.graph.nodes:
1523            if node.target != exir_ops.edge.cadence.quantized_conv.default:
1524                continue
1525            # Check that the channel_last argument is set to True.
1526            self.assertEqual(len(node.args), 15, f"{node=}")
1527            self.assertTrue(node.args[14])
1528
1529    def test_no_transpose_if_already_quantized_conv_channel_last(self):
1530        # Create a graph with a single quantized convolution node that is already channel-last.
1531        gm = self.create_quantized_convolution_graph_module(channels_last=True)
1532        # Check that the graph module is valid by running ExportPass on it.
1533        gm = ExportPass().call(gm).graph_module
1534        self.assertEqual(
1535            count_node(gm, exir_ops.edge.cadence.quantized_conv.default), 1
1536        )
1537
1538        # Apply replacement pass.
1539        p = ForceChannelLastForConvPass()
1540        gm_after_replacement = p.call(gm).graph_module
1541        # Check that no replacement was made.
1542        self.assertEqual(
1543            count_node(
1544                gm_after_replacement, exir_ops.edge.cadence.quantized_conv.default
1545            ),
1546            1,
1547        )
1548        self.assertEqual(count_node(gm_after_replacement, exir_ops.edge.aten.permute_copy.default), 0)
1549        for node in gm_after_replacement.graph.nodes:
1550            if node.target != exir_ops.edge.cadence.quantized_conv.default:
1551                continue
1552            # Check that the channel_last argument is set to True.
1553            self.assertEqual(len(node.args), 15, f"{node=}")
1554            self.assertTrue(node.args[14])
1555
1556
1557class TestMakeSliceAndCatDimOutermostPass(unittest.TestCase):
1558    def create_slice_graph(
1559        self,
1560        input_shape: Sequence[int],
1561        slice_dim: int,
1562        slice_begin: Optional[int] = None,
1563        slice_end: Optional[int] = None,
1564    ) -> torch.fx.GraphModule:
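        # aten.slice_copy.Tensor takes (input, dim, start, end); a start/end of None
        # selects the full extent along dim, and step defaults to 1.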
1565        x = torch.randn(*input_shape)
1566        return single_op_builder(
1567            placeholders=(x,),
1568            op=exir_ops.edge.aten.slice_copy.Tensor,
1569            args=(x, slice_dim, slice_begin, slice_end),
1570        )
1571
1572    def test_slice_no_transpose_if_already_outermost(self):
1573        # Create a graph with a single slice node.
1574        gm = self.create_slice_graph((3, 224, 224), 0, 1, 2)
1575        # Check that the graph module is valid by running ExportPass on it.
1576        gm = ExportPass().call(gm).graph_module
1577        self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_copy.Tensor), 1)
1578
1579        # Apply replacement pass.
1580        gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module
1581
1582        # Assert that no transpose ops were added.
1583        self.assertEqual(
1584            count_node(gm_after_pass, exir_ops.edge.aten.transpose_copy.int),
1585            0,
1586        )
1587
1588    def test_slice_no_transpose_if_outermost_dimensions_are_one(self):
1589        # Create a graph with a single slice node on the second-outermost dimension.
1590        gm = self.create_slice_graph((1, 3, 4, 6), 1, 1, 2)
1591        # Check that the graph module is valid by running ExportPass on it.
1592        gm = ExportPass().call(gm).graph_module
1593        self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_copy.Tensor), 1)
1594
1595        # Apply replacement pass.
1596        gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module
1597
1598        # Assert that no transpose ops were added. The slice is on the second
1599        # outermost dimension, but the outermost dimension is already 1.
1600        self.assertEqual(
1601            count_node(gm_after_pass, exir_ops.edge.aten.transpose_copy.int),
1602            0,
1603        )
1604
1605    def test_slice_insert_transpose(self):
1606        # Create a graph with a single slice node.
1607        gm = self.create_slice_graph((1, 3, 4, 6), 2, 1, 2)
1608        # Check that the graph module is valid by running ExportPass on it.
1609        gm = ExportPass().call(gm).graph_module
1610        self.assertEqual(count_node(gm, exir_ops.edge.aten.slice_copy.Tensor), 1)
1611
1612        # Apply replacement pass.
1613        gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module
1614
1615        # Assert that there are two transpose ops added.
1616        self.assertEqual(
1617            count_node(gm_after_pass, exir_ops.edge.aten.transpose_copy.int),
1618            2,
1619        )
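        # The two transposes presumably bracket the slice: one moves the sliced dim
        # to the outermost position and one restores the original layout afterwards.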
1620
1621    def create_cat_graph(
1622        self,
1623        input_shapes: Sequence[Sequence[int]],
1624        cat_dim: int = 0,
1625    ) -> torch.fx.GraphModule:
1626        input_tensors = tuple(torch.randn(s) for s in input_shapes)
1627        return single_op_builder(
1628            placeholders=input_tensors,
1629            op=exir_ops.edge.aten.cat.default,
1630            args=(input_tensors, cat_dim),
1631        )
1632
1633    def test_cat_no_transpose_if_already_outermost(self):
1634        # Create a graph with a single cat node on the outermost dimension.
1635        gm = self.create_cat_graph(input_shapes=((1, 3, 5), (2, 3, 5)), cat_dim=0)
1636        # Check that the graph module is valid by running ExportPass on it.
1637        gm = ExportPass().call(gm).graph_module
1638        self.assertEqual(count_node(gm, exir_ops.edge.aten.cat.default), 1)
1639
1640        # Apply replacement pass.
1641        gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module
1642
1643        # Assert that no transpose ops were added: the cat is already on the
1644        # outermost dimension.
1645        self.assertEqual(
1646            count_node(gm_after_pass, exir_ops.edge.aten.transpose_copy.int),
1647            0,
1648        )
1649
1650    def test_cat_no_transpose_if_outermost_dimensions_are_one(self):
1651        # Create a graph with a single cat node on the second-outermost dimension.
1652        gm = self.create_cat_graph(input_shapes=((1, 1, 3, 5), (1, 2, 3, 5)), cat_dim=1)
1653        # Check that the graph module is valid by running ExportPass on it.
1654        gm = ExportPass().call(gm).graph_module
1655        self.assertEqual(count_node(gm, exir_ops.edge.aten.cat.default), 1)
1656
1657        # Apply replacement pass.
1658        gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module
1659
1660        # Assert that no transpose ops were added. The cat is on the second-outermost
1661        # dimension, but the outermost dimension is already 1.
1662        self.assertEqual(
1663            count_node(gm_after_pass, exir_ops.edge.aten.transpose_copy.int),
1664            0,
1665        )
1666
1667    def test_cat_insert_transpose(self):
1668        # Create a graph with a single cat node on the innermost dimension.
1669        gm = self.create_cat_graph(
1670            input_shapes=((1, 1, 3, 5), (1, 1, 3, 3)), cat_dim=-1
1671        )
1672        # Check that the graph module is valid by running ExportPass on it.
1673        gm = ExportPass().call(gm).graph_module
1674        self.assertEqual(count_node(gm, exir_ops.edge.aten.cat.default), 1)
1675
1676        # Apply replacement pass.
1677        gm_after_pass = MakeSliceAndCatDimOutermostPass()(gm).graph_module
1678
1679        # Assert that transpose ops were added to make the cat dimension outermost.
1680        self.assertEqual(
1681            count_node(gm_after_pass, exir_ops.edge.aten.transpose_copy.int),
1682            3,
1683        )
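        # The three transposes presumably cover one per cat input (moving the cat dim
        # to the outermost position) plus one restoring the original layout of the output.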
1684