# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.


import unittest

import executorch.backends.cadence.aot.ops_registrations  # noqa
import torch
from executorch.backends.cadence.aot import compiler
from executorch.backends.cadence.aot.compiler import export_to_edge, quantize_pt2
from executorch.backends.cadence.aot.fuse_ops import (
    FuseFullThenReshapePass,
    FuseMulIntoDequantPass,
    FuseQuantDequantToRequantizePass,
    FuseTransposeOpPairsPass,
)
from executorch.backends.cadence.aot.graph_builder import GraphBuilder
from executorch.backends.cadence.aot.pass_utils import count_node
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from torch import nn


class TestFusionPassesBase(unittest.TestCase):
    def check_op_counts(
        self,
        graph_module: torch.fx.GraphModule,
        expected_op_counts: dict[EdgeOpOverload, int],
    ) -> None:
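        """Assert that each op in expected_op_counts appears in graph_module
        exactly the expected number of times."""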
        for op, count in expected_op_counts.items():
            self.assertEqual(count_node(graph_module, op), count)


class TestFusionPasses(TestFusionPassesBase):
    def test_addmm_fusion(self):
        class AddmmFeasible1(torch.nn.Module):
            def forward(self, x, y, z):
                t1 = torch.mm(x, y)
                return torch.add(t1, z)

        x = torch.randn(3, 5)
        y = torch.randn(5, 6)
        z = torch.randn(6)

        graph_module = (
            compiler.export_to_cadence(AddmmFeasible1(), (x, y, z))
            .exported_program()
            .graph_module
        )
        graph_module.graph.eliminate_dead_code()

        # Assert that mm and add were fused to addmm
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1)
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0)
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0)
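        # (addmm(z, x, y) computes z + x @ y, so the fusion is exact whenever z
        # broadcasts to the output of mm.)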

        class AddmmFeasible2(torch.nn.Module):
            def forward(self, x, y, z):
                t1 = y.view((8, 6))
                t2 = torch.mm(x, t1)
                t3 = t2.view((2, 2, 6))
                return torch.add(t3, z)

        x = torch.randn(4, 8)
        y = torch.randn(2, 4, 6)
        z = torch.randn(6)

        graph_module = (
            compiler.export_to_cadence(AddmmFeasible2(), (x, y, z))
            .exported_program()
            .graph_module
        )
        graph_module.graph.eliminate_dead_code()
        # Assert that mm and add were fused to addmm
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1)
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0)
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0)

        # Bias is a singleton value, broadcastable to output of mm
        class AddmmFeasible3(torch.nn.Module):
            def forward(self, x, y):
                t1 = torch.mm(x, y)
                return torch.add(t1, torch.ones(1))

        x = torch.randn(3, 5)
        y = torch.randn(5, 6)

        graph_module = (
            compiler.export_to_cadence(AddmmFeasible3(), (x, y))
            .exported_program()
            .graph_module
        )
        graph_module.graph.eliminate_dead_code()
        # Assert that mm and add were fused to addmm
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.addmm.default), 1)
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.mm.default), 0)
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 0)

        # Bias is not broadcastable to output of mm
        class AddmmInfeasible1(torch.nn.Module):
            def forward(self, x, y, z):
                t1 = y.view((8, 6))
                t2 = torch.mm(x, t1)
                t3 = t2.view((2, 2, 6))
                return torch.add(t3, z)

        x = torch.randn(4, 8)
        y = torch.randn(2, 4, 6)
        z = torch.randn(2, 2, 1)

        graph_module = (
            compiler.export_to_cadence(AddmmInfeasible1(), (x, y, z))
            .exported_program()
            .graph_module
        )
        graph_module.graph.eliminate_dead_code()
        # Assert that mm and add were not fused to addmm, since z (shape
        # (2, 2, 1)) cannot be broadcast to the output of mm (shape (4, 6)).
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 1)

        # The add consuming the output of mm has more than one user.
        class AddmmInfeasible2(torch.nn.Module):
            def forward(self, x, y, z):
                t1 = torch.mm(x, y)
                t2 = torch.add(t1, z)
                t3 = torch.add(t2, z)
                return torch.add(t2, t3)

        x = torch.randn(3, 5)
        y = torch.randn(5, 6)
        z = torch.randn(6)

        graph_module = (
            compiler.export_to_cadence(AddmmInfeasible2(), (x, y, z))
            .exported_program()
            .graph_module
        )
        graph_module.graph.eliminate_dead_code()
        # Assert that mm and add were not fused to addmm, since add has multiple
        # users.
        self.assertEqual(count_node(graph_module, exir_ops.edge.aten.add.Tensor), 3)

    # TODO(matthiascremon): enable that pass with new flow
    @torch.no_grad()
    @unittest.expectedFailure
    def test_legacy_conv_bn_fusion(self):
        class ModelConvBN(torch.nn.Module):
            def __init__(self, in_features: int, out_features: int, kernel_size: int):
                super().__init__()
                self.conv1d = nn.Conv1d(in_features, out_features, kernel_size)
                self.bn = nn.BatchNorm1d(out_features)

            def forward(self, x):
                y = self.conv1d(x)
                return self.bn(y)

        model = ModelConvBN(64, 1, 2)
        x = torch.randn(1, 64, 4)

        graph_module = (
            compiler.export_to_executorch(model.eval(), (x,))
            .exported_program()
            .graph_module
        )
        # Assert that after running the fusion passes, batchnorm was fused with conv1d
        self.assertEqual(
            count_node(graph_module, torch.ops.aten.linear.out)
            + count_node(graph_module, torch.ops.cadence.convolution.out),
            1,
        )
        self.assertEqual(
            count_node(
                graph_module, torch.ops.aten._native_batch_norm_legit_no_training.out
            ),
            0,
        )

    def test_permute_transpose_fusion(self):
        class PermuteTranspose(torch.nn.Module):
            def forward(self, x):
                y = x.permute((0, 2, 4, 1, 3))
                return y.transpose(0, 1)

        x = torch.randn(3, 1, 3, 1, 4)
        graph_module = (
            compiler.export_to_cadence(PermuteTranspose(), (x,))
            .exported_program()
            .graph_module
        )
        graph_module.graph.eliminate_dead_code()
        # Assert that permute op was fused with transpose op
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.permute_copy.default), 1
        )
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.transpose_copy.int), 0
        )

    def test_view_fusion(self):
        class ViewFusion(torch.nn.Module):
            def forward(self, x):
                x = x.view([1, 8, 15])
                x = x.view([1, 1, 120])
                return x.view([1, 12, 10])

        x = torch.randn(8, 5, 3)
        graph_module = (
            compiler.export_to_cadence(ViewFusion(), (x,))
            .exported_program()
            .graph_module
        )
        graph_module.graph.eliminate_dead_code()
        # Assert that only one view op remains
        self.assertEqual(
            count_node(graph_module, exir_ops.edge.aten.view_copy.default), 1
        )

    def test_force_quant_dequant_fusion(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = torch.ops.quantized_decomposed.quantize_per_tensor(
                    x, 1.2, 3, 0, 127, torch.int8
                )
                x = torch.permute(x, [2, 0, 1, 3])
                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                    x, 4.5, 6, 0, 127, torch.int8
                )
                return x

        inputs = torch.randn(2, 12, 1, 6)
        model = M()
        graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module

        graph_module = FuseQuantDequantToRequantizePass(
            force_quant_dequant_fusion=True
        )(graph_module).graph_module
        self.check_op_counts(
            graph_module,
            expected_op_counts={
                # Verify that the quantize -> permute -> dequantize chain was
                # replaced with requantize, since force_quant_dequant_fusion is set.
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
                exir_ops.edge.cadence.requantize.default: 1,
            },
        )

    def test_no_replace_quant_permute_dequant_with_requantize(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = torch.ops.quantized_decomposed.quantize_per_tensor(
                    x, 1.2, 3, 0, 127, torch.int8
                )
                x = torch.permute(x, [2, 0, 1, 3])
                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                    x, 4.5, 6, 0, 127, torch.int8
                )
                return x

        inputs = torch.randn(2, 12, 1, 6)
        model = M()
        graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module

        graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
        self.check_op_counts(
            graph_module,
            expected_op_counts={
                # Verify that no dequant/quant pair was replaced with requantize.
                # quantize -> permute -> dequantize should not be replaced with requantize.
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1,
                exir_ops.edge.cadence.requantize.default: 0,
            },
        )

    def test_replace_quant_view_dequant_with_requantize(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = torch.ops.quantized_decomposed.quantize_per_tensor(
                    x, 1.2, 3, 0, 127, torch.int8
                )
                x = x.view(-1)
                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                    x, 4.5, 6, 0, 127, torch.int8
                )
                return x

        inputs = torch.randn(2, 12, 1, 6)
        model = M()
        graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
        graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
        graph_module.print_readable()

        self.check_op_counts(
            graph_module,
            expected_op_counts={
                # Verify that the quantize -> view -> dequantize chain was
                # replaced with requantize.
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
                exir_ops.edge.cadence.requantize.default: 1,
            },
        )

    def test_replace_dequant_quant_with_requantize(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                    x, 1.2, 3, 0, 127, torch.int8
                )
                x = torch.ops.quantized_decomposed.quantize_per_tensor(
                    x, 4.5, 6, 0, 127, torch.int8
                )
                return x

        inputs = torch.randn(2, 12, 1, 6).to(torch.int8)
        model = M()
        graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
        graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module

        self.check_op_counts(
            graph_module,
            expected_op_counts={
                # Verify that dequant -> quant was replaced with requantize.
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
                exir_ops.edge.cadence.requantize.default: 1,
            },
        )

    def test_replace_dequant_permute_quant_with_requantize(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
                    x, 1.2, 3, 0, 127, torch.int8
                )
                x = torch.permute(x, [2, 0, 1, 3])
                x = torch.ops.quantized_decomposed.quantize_per_tensor(
                    x, 4.5, 6, 0, 127, torch.int8
                )
                return x

        inputs = torch.randn(2, 12, 1, 6).to(torch.int8)
        model = M()
        graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
        graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module

        self.check_op_counts(
            graph_module,
            expected_op_counts={
                # Verify that dequant -> permute -> quant was replaced with permute -> requantize.
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
                exir_ops.edge.cadence.requantize.default: 1,
            },
        )

    def test_remove_nop_dequant_quant(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.lin1 = torch.nn.Linear(6, 12, bias=False)
                self.lin2 = torch.nn.Linear(12, 24, bias=False)

            def forward(self, x):
                x = self.lin1(x)
                # redundant dequant+quant will be created around this permute
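                # (the quantizer keeps permute in floating point, so export
                # brackets it with a dequantize/quantize pair; with matching
                # qparams the pair is a nop that the fusion pass can remove)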
                x = torch.permute(x, [0, 2, 1, 3])
                x = self.lin2(x)
                return x

        inputs = torch.randn(2, 12, 1, 6)
        model = M()
        quantized_model = quantize_pt2(model, (inputs,))
        graph_module = (
            export_to_edge(quantized_model, (inputs,)).exported_program().graph_module
        )
        graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
        self.check_op_counts(
            graph_module,
            expected_op_counts={
                # Verify that one dequant/quant pair was removed.
                # Expect 1 quantize op, at the model input:
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
                # Expect 1 dequant op at the end (output of second linear)
                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1,
            },
        )

    def test_fuse_mul_into_dequant(self):
        class M(torch.nn.Module):
            def forward(self, x):
                x0 = torch.ops.quantized_decomposed.dequantize_per_tensor(
                    x, 1.5, 0, 0, 255, torch.uint8
                )
                x1 = torch.full([4, 32], 3, dtype=torch.float32)
                x2 = x0 * x1
                return x2

        inputs = (torch.randint(0, 255, [4, 32], dtype=torch.uint8),)
        graph_module = export_to_edge(M(), inputs).exported_program().graph_module
        graph_module = FuseMulIntoDequantPass()(graph_module).graph_module

        # verify that the mul and full ops were removed
        self.check_op_counts(
            graph_module,
            expected_op_counts={
                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1,
                exir_ops.edge.aten.full.default: 0,
                exir_ops.edge.aten.mul.Tensor: 0,
            },
        )

        # verify that the dequant scale value was updated correctly
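        # (the pass folds the constant multiplier from the full op into the
        # dequant scale: 1.5 * 3 = 4.5)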
        for node in graph_module.graph.nodes:
            if (
                node.target
                == exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
            ):
                deq_scale = node.args[1]
        self.assertEqual(deq_scale, 4.5)

    def test_fuse_then_transpose_pass(self):
        # Create a graph with full -> transpose -> permute -> view.
        builder = GraphBuilder()
        full_node = builder.call_operator(
            op=exir_ops.edge.aten.full.default, args=((2, 3), 1)
        )
        transpose_node = builder.call_operator(
            op=exir_ops.edge.aten.transpose_copy.int,
            args=(full_node, 0, 1),
        )
        permute_node = builder.call_operator(
            op=exir_ops.edge.aten.permute_copy.default,
            args=(transpose_node, (1, 0)),
        )
        view_node = builder.call_operator(
            op=exir_ops.edge.aten.view_copy.default,
            args=(permute_node, (1, 6, 1)),
        )
        builder.output(view_node)
        gm = builder.get_graph_module()
        self.check_op_counts(
            gm,
            expected_op_counts={
                exir_ops.edge.aten.full.default: 1,
                exir_ops.edge.aten.transpose_copy.int: 1,
                exir_ops.edge.aten.permute_copy.default: 1,
                exir_ops.edge.aten.view_copy.default: 1,
            },
        )

        # Check that the pass fuses the full with all other ops (transpose, permute, view).
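        # (full materializes a constant-filled tensor, so shape-only ops on its
        # output can be folded by emitting full with the final shape directly)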
        gm_after_pass = FuseFullThenReshapePass()(gm).graph_module
        self.check_op_counts(
            gm_after_pass,
            expected_op_counts={
                exir_ops.edge.aten.full.default: 1,
                exir_ops.edge.aten.transpose_copy.int: 0,
                exir_ops.edge.aten.permute_copy.default: 0,
                exir_ops.edge.aten.view_copy.default: 0,
            },
        )


class TestFuseTransposeOpPairsPass(TestFusionPassesBase):
    def test_fuse_transpose_pairs(self):
        # Create a graph with transpose -> quant -> transpose.
        builder = GraphBuilder()
        x = builder.placeholder("x", torch.randn(2, 3))
        transpose_node = builder.call_operator(
            op=exir_ops.edge.aten.transpose_copy.int,
            args=(x, 0, 1),
        )
        quant_node = builder.call_operator(
            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            args=(transpose_node, 1.2, 3, 0, 127, torch.int8),
        )
        transpose_node = builder.call_operator(
            op=exir_ops.edge.aten.transpose_copy.int,
            args=(quant_node, 0, 1),
        )
        builder.output(transpose_node)
        gm = builder.get_graph_module()
        self.check_op_counts(
            gm,
            expected_op_counts={
                exir_ops.edge.aten.transpose_copy.int: 2,
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
            },
        )

        # Check that the pass fuses the two transpose ops.
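        # (quantize_per_tensor is elementwise, so the identical transpose(0, 1)
        # ops on either side of it compose to the identity)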
        gm_after_pass = FuseTransposeOpPairsPass()(gm).graph_module
        self.check_op_counts(
            gm_after_pass,
            expected_op_counts={
                exir_ops.edge.aten.transpose_copy.int: 0,
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
            },
        )

    def test_no_fusion_for_transpose_pairs(self):
        # Create a graph with transpose -> quant -> transpose.
        builder = GraphBuilder()
        x = builder.placeholder("x", torch.randn(2, 3, 4))
        transpose_node = builder.call_operator(
            op=exir_ops.edge.aten.transpose_copy.int,
            args=(x, 0, 1),
        )
        quant_node = builder.call_operator(
            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            args=(transpose_node, 1.2, 3, 0, 127, torch.int8),
        )
        transpose_node = builder.call_operator(
            op=exir_ops.edge.aten.transpose_copy.int,
            args=(quant_node, 1, 2),
        )
        builder.output(transpose_node)
        gm = builder.get_graph_module()
        self.check_op_counts(
            gm,
            expected_op_counts={
                exir_ops.edge.aten.transpose_copy.int: 2,
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
            },
        )

        # No fusion.
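        # (the transposes swap different dim pairs, (0, 1) then (1, 2), so they
        # do not compose to the identity and must both be kept)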
        gm_after_pass = FuseTransposeOpPairsPass()(gm).graph_module
        self.check_op_counts(
            gm_after_pass,
            expected_op_counts={
                exir_ops.edge.aten.transpose_copy.int: 2,
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
            },
        )

    def test_fusion_for_forked_transposes(self):
        # Create a graph where one transpose forks into several quant -> transpose branches.
        builder = GraphBuilder()
        x = builder.placeholder("x", torch.randn(2, 3, 4, dtype=torch.float32))
        transpose_node = builder.call_operator(
            op=exir_ops.edge.aten.transpose_copy.int,
            args=(x, 0, 1),
        )
        num_forks = 3
        outputs = []
        for _ in range(num_forks):
            quant_node = builder.call_operator(
                op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
                args=(transpose_node, 1.2, 3, 0, 127, torch.int8),
            )
            outputs.append(
                builder.call_operator(
                    op=exir_ops.edge.aten.transpose_copy.int,
                    args=(quant_node, 0, 1),
                )
            )
        builder.output(outputs)
        gm = builder.get_graph_module()
        self.check_op_counts(
            gm,
            expected_op_counts={
                exir_ops.edge.aten.transpose_copy.int: num_forks + 1,
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: num_forks,
            },
        )

        # Fuse all the transpose ops.
        gm_after_pass = FuseTransposeOpPairsPass()(gm).graph_module
        self.check_op_counts(
            gm_after_pass,
            expected_op_counts={
                exir_ops.edge.aten.transpose_copy.int: 0,
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: num_forks,
            },
        )
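

if __name__ == "__main__":
    # Convenience entry point so the tests can be run directly with
    # `python test_fusion_ops_passes.py`; the regular test runner does not need it.
    unittest.main()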
595