# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

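"""Unit tests for the Cadence AoT graph-reordering passes.

These exercise AdvanceQuantizeOpAboveDefInBranchPass,
PostponeDequantizeOpBelowUseChainPass, and
PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView, along with their
interaction with FuseQuantDequantToRequantizePass.
"""
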
import unittest

import executorch.backends.cadence.aot.ops_registrations  # noqa
import torch
from executorch.backends.cadence.aot.compiler import (
    export_to_edge,
    quantize_and_export_to_cadence,
)
from executorch.backends.cadence.aot.fuse_ops import FuseQuantDequantToRequantizePass
from executorch.backends.cadence.aot.pass_utils import (
    count_node,
    get_compute_nodes_in_gm,
    nodes_not_adjacent_in_gm,
    nodes_not_connected_in_gm,
)
from executorch.backends.cadence.aot.reorder_ops import (
    AdvanceQuantizeOpAboveDefInBranchPass,
    PostponeDequantizeOpBelowUseChainPass,
    PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView,
)
from executorch.exir.dialects._ops import ops as exir_ops


class TestReorderPasses(unittest.TestCase):
    def test_sink_dequantize(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(6, 12, bias=False)

            def forward(self, x, y):
                x1 = self.linear(x)
                y1 = self.linear(y)
                x2 = torch.ops.aten.abs(x1)
                return torch.ops.aten.cat((x2, y1))

        inputs = (torch.randn(32, 6), torch.randn(32, 6))
        graph_module = (
            quantize_and_export_to_cadence(M(), inputs).exported_program().graph_module
        )
        # Expect the SinkDequant pass to move the dequant from above the abs
        # to just below it.
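        # Rough sketch of the rewrite (exact node names depend on the
        # quantization pipeline):
        #   before: quantized_linear -> dequant -> abs -> cat
        #   after:  quantized_linear -> abs -> dequant -> cat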
        self.assertTrue(
            nodes_not_adjacent_in_gm(
                graph_module,
                exir_ops.edge.aten.abs.default,
                exir_ops.edge.aten.cat.default,
            ),
        )
        self.assertTrue(
            nodes_not_adjacent_in_gm(
                graph_module,
                exir_ops.edge.cadence.dequantize_per_tensor.default,
                exir_ops.edge.cadence.dequantize_per_tensor.default,
            ),
        )

    def test_advance_branched_quantize(self):
        class ReorderOpsBranch(torch.nn.Module):
            def forward(self, x):
                x = x.view((32, 6))
                x1 = torch.slice_copy(x, dim=0, start=0, end=6, step=1)
                x1 = exir_ops.edge.quantized_decomposed.quantize_per_tensor(
                    x1, 0.1, 10, 0, 255, torch.uint8
                )
                x2 = torch.slice_copy(x, dim=0, start=6, end=12, step=1)
                x2 = exir_ops.edge.quantized_decomposed.quantize_per_tensor(
                    x2, 0.1, 10, 0, 255, torch.uint8
                )
                x3 = torch.slice_copy(x, dim=0, start=12, end=18, step=1)
                x3 = exir_ops.edge.quantized_decomposed.quantize_per_tensor(
                    x3, 0.1, 10, 0, 255, torch.uint8
                )
                x4 = torch.slice_copy(x, dim=0, start=18, end=24, step=1)
                x4 = exir_ops.edge.quantized_decomposed.quantize_per_tensor(
                    x4, 0.2, 4, 0, 255, torch.uint8
                )
                return (x1, x2, x3, x4)

        model = ReorderOpsBranch()
        X = torch.randn(64, 3)
        graph_module = export_to_edge(model, (X,)).exported_program().graph_module
        graph_module = AdvanceQuantizeOpAboveDefInBranchPass()(
            graph_module
        ).graph_module
        graph_module.graph.eliminate_dead_code()
        nodes = get_compute_nodes_in_gm(graph_module)
        # The quantize op should be hoisted to dominate the branch.
        self.assertTrue(
            nodes[0] == exir_ops.edge.quantized_decomposed.quantize_per_tensor
        )
        # There should be 5 quantize ops: the 4 originally present in the
        # model, and the one that was hoisted above the slices.
        self.assertEqual(
            count_node(
                graph_module,
                exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
            ),
            5,
        )
        # Ensure none of the slice nodes was erroneously removed.
        self.assertEqual(
            count_node(
                graph_module,
                exir_ops.edge.aten.slice_copy.Tensor,
            ),
            4,
        )
        # Each of the 4 original quant ops should now be paired with a dequant op.
        self.assertEqual(
            count_node(
                graph_module,
                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            ),
            4,
        )
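
        # At this point the graph is roughly (a simplification; exact node
        # order depends on the export pipeline):
        #   quant(0.1, 10) -> view -> slice_i -> dequant(0.1, 10) -> quant_i
        # For x1..x3 the per-branch dequant/quant params match the hoisted
        # quant, so those pairs are no-ops; x4 (scale 0.2, zero_point 4)
        # differs, so its pair should fold into a single requantize below.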
        graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
        # We expect 3 dequant/quant pairs to be removed because they have
        # matching params, leaving a single dequant/quant pair that is then
        # merged into a requantize op.
        self.assertEqual(
            count_node(
                graph_module,
                exir_ops.edge.cadence.requantize.default,
            ),
            1,
        )

    @torch.no_grad()
    def test_advance_quantize(self):
        class ReorderOpsChain(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(6, 12, bias=False)

            def forward(self, x):
                x = x.permute([1, 0, 3, 2])
                x = self.linear(x)
                return x

        model = ReorderOpsChain()
        X = torch.randn(16, 1, 6, 32)

        graph_module = (
            quantize_and_export_to_cadence(model, (X,)).exported_program().graph_module
        )
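        # Expected rewrite (informal sketch): permute -> quantize becomes
        # quantize -> permute, so the permute runs on already-quantized data.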
        # Assert that the quant node is no longer the successor of the
        # permute node.
        self.assertTrue(
            nodes_not_connected_in_gm(
                graph_module,
                exir_ops.edge.aten.permute_copy.default,
                exir_ops.edge.cadence.quantize_per_tensor.default,
            ),
        )
        # Assert that the permute node is now the successor of the quant node.
        self.assertFalse(
            nodes_not_connected_in_gm(
                graph_module,
                exir_ops.edge.cadence.quantize_per_tensor.default,
                exir_ops.edge.aten.permute_copy.default,
            ),
        )

    def test_postpone_dequantize(self):
        class ReorderOpsChain(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(6, 12, bias=False)

            def forward(self, x):
                x = self.linear(x)
                x = x.permute([1, 0, 3, 2])
                return x

        model = ReorderOpsChain()
        X = torch.randn(1, 16, 32, 6)

        graph_module = (
            quantize_and_export_to_cadence(model, (X,)).exported_program().graph_module
        )
        # Assert that the dequant node is no longer the predecessor of the
        # permute node.
        self.assertTrue(
            nodes_not_connected_in_gm(
                graph_module,
                exir_ops.edge.cadence.dequantize_per_tensor.default,
                exir_ops.edge.aten.permute_copy.default,
            ),
        )
        # Assert that the dequant node is now the successor of the permute node.
        self.assertFalse(
            nodes_not_connected_in_gm(
                graph_module,
                exir_ops.edge.aten.permute_copy.default,
                exir_ops.edge.cadence.dequantize_per_tensor.default,
            ),
        )

    def test_postpone_dequantize_branched(self):
        class ReorderOpsBranch(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(3, 12, bias=False)

            def forward(self, x):
                x0 = exir_ops.edge.quantized_decomposed.dequantize_per_tensor(
                    x, 0.1, 10, 0, 255, torch.uint8
                )
                x0 = torch.squeeze(x0, 0)
                x1 = torch.slice_copy(x0, dim=0, start=0, end=6, step=1)
                x1 = self.linear(x1)

                x2 = torch.slice_copy(x0, dim=0, start=6, end=12, step=1)
                x2 = self.linear(x2)

                x3 = torch.slice_copy(x0, dim=0, start=12, end=18, step=1)
                x3 = self.linear(x3)

                return (x1, x2, x3)

        model = ReorderOpsBranch()
        X = torch.randint(0, 255, [1, 18, 3], dtype=torch.uint8)
        graph_module = export_to_edge(model, (X,)).exported_program().graph_module
        graph_module = PostponeDequantizeOpBelowUseChainPass()(
            graph_module
        ).graph_module
        graph_module.graph.eliminate_dead_code()

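        # Expected rewrite (informal sketch): the single dequant above the
        # squeeze is pushed below the squeeze and replicated into each slice
        # branch:
        #   before: dequant -> squeeze -> slice_i -> linear
        #   after:  squeeze -> slice_i -> dequant_i -> linear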
        # Assert that the dequant node was split into 3, one per branch.
        self.assertEqual(
            count_node(
                graph_module,
                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
            ),
            3,
        )

        # Assert that the dequant node is no longer the predecessor of the
        # squeeze node.
        self.assertTrue(
            nodes_not_connected_in_gm(
                graph_module,
                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
                exir_ops.edge.aten.squeeze_copy.dims,
            ),
        )
        # Assert that the dequant node is not a predecessor of the slice
        # (it should have been moved below the slice).
        self.assertTrue(
            nodes_not_connected_in_gm(
                graph_module,
                exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
                exir_ops.edge.aten.slice_copy.Tensor,
            ),
        )

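    # The next two tests exercise PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView.
    # Informal sketch of the idea: when a view only adds or removes size-1 dims
    # (i.e. it is squeeze/unsqueeze-like), a preceding permute can be commuted
    # below it with a suitably adjusted permutation, producing the same tensor.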
    # 3d -> permute -> 3d -> view -> 4d
    def test_permute3_view4_chains(self):
        class PermuteViewChain(torch.nn.Module):
            def forward(self, x):
                # x is [3, 1, 768]
                x = x.view((3, 12, 64))
                # x is [3, 12, 64]
                x = x.permute([1, 0, 2])
                # x is [12, 3, 64]
                x = x.view((1, 12, 3, 64))
                # x is [1, 12, 3, 64]
                x = x.permute([0, 1, 3, 2])
                # x is [1, 12, 64, 3]
                return x

        model = PermuteViewChain()
        X = torch.randn(3, 1, 768)
        graph_module = export_to_edge(model, (X,)).exported_program().graph_module

        # Perform the transform.
        graph_module = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()(
            graph_module
        ).graph_module
        graph_module.graph.eliminate_dead_code()

        # Assert the order becomes view, view, permute, permute.
        nodes = get_compute_nodes_in_gm(graph_module)
        self.assertEqual(len(nodes), 4)
        self.assertTrue(nodes[0] == exir_ops.edge.aten.view_copy)
        self.assertTrue(nodes[1] == exir_ops.edge.aten.view_copy)
        self.assertTrue(nodes[2] == exir_ops.edge.aten.permute_copy)
        self.assertTrue(nodes[3] == exir_ops.edge.aten.permute_copy)

    # 4d -> permute -> 4d -> view -> 3d
    def test_permute4_view3_chains(self):
        class PermuteViewChain(torch.nn.Module):
            def forward(self, x):
                # x is [3, 1, 768]
                x = x.view((1, 3, 12, 64))
                # x is [1, 3, 12, 64]
                x = x.permute([3, 1, 0, 2])
                # x is [64, 3, 1, 12]
                x = x.view((64, 3, 12))
                # x is [64, 3, 12]
                x = x.permute([2, 1, 0])
                # x is [12, 3, 64]
                return x

        model = PermuteViewChain()
        X = torch.randn(3, 1, 768)
        graph_module = export_to_edge(model, (X,)).exported_program().graph_module

        # Perform the transform.
        graph_module = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()(
            graph_module
        ).graph_module
        graph_module.graph.eliminate_dead_code()

        # Assert the order becomes view, view, permute, permute.
        nodes = get_compute_nodes_in_gm(graph_module)
        self.assertEqual(len(nodes), 4)
        self.assertTrue(nodes[0] == exir_ops.edge.aten.view_copy)
        self.assertTrue(nodes[1] == exir_ops.edge.aten.view_copy)
        self.assertTrue(nodes[2] == exir_ops.edge.aten.permute_copy)
        self.assertTrue(nodes[3] == exir_ops.edge.aten.permute_copy)

    # Negative test case where the transform should not happen:
    # 4d -> permute -> 4d -> view -> 3d, where the view not only removes the
    # dimension whose size is 1 (which is ok) but also changes the sizes of
    # the other dimensions (which is not ok).
    def test_permute_view_chains_neg(self):
        class PermuteViewChain(torch.nn.Module):
            def forward(self, x):
                # x is [3, 1, 768]
                x = x.view((1, 3, 12, 64))
                # x is [1, 3, 12, 64]
                x = x.permute([3, 1, 0, 2])
                # x is [64, 3, 1, 12]
                x = x.view((64, 6, 6))
                # x is [64, 6, 6]
                x = x.permute([2, 1, 0])
                # x is [6, 6, 64]
                return x

        model = PermuteViewChain()
        X = torch.randn(3, 1, 768)
        graph_module = export_to_edge(model, (X,)).exported_program().graph_module

        # Perform the transform (nothing should happen).
        graph_module = PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView()(
            graph_module
        ).graph_module
        graph_module.graph.eliminate_dead_code()

        # Assert the order is still view, permute, view, permute.
        nodes = get_compute_nodes_in_gm(graph_module)
        self.assertEqual(len(nodes), 4)
        self.assertTrue(nodes[0] == exir_ops.edge.aten.view_copy)
        self.assertTrue(nodes[1] == exir_ops.edge.aten.permute_copy)
        self.assertTrue(nodes[2] == exir_ops.edge.aten.view_copy)
        self.assertTrue(nodes[3] == exir_ops.edge.aten.permute_copy)

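# Standard unittest entry point so the file can also be run directly
# (assumption: the project's own test runner is the usual way to run these).
if __name__ == "__main__":
    unittest.main()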