xref: /aosp_15_r20/external/executorch/backends/qualcomm/_passes/fuse_consecutive_transpose.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7
8import torch
9from executorch.backends.qualcomm.utils.constants import QCOM_INSERTED_PERMUTE
10
11from executorch.exir.dialects._ops import ops as exir_ops
12from executorch.exir.pass_base import ExportPass, PassResult
13from executorch.exir.passes import dead_code_elimination_pass
14
15
class FuseConsecutiveTranspose(ExportPass):
    """
    Fuse chains of consecutive ``permute_copy`` nodes into a single permute to
    reduce runtime overhead.

    A chain ``x -> permute(p1) -> permute(p2) -> ... -> permute(pk)`` is
    replaced by one ``permute(x, order)``, where ``order`` is the composition
    of ``p1 .. pk``. Users of the last permute in the chain are rewired to the
    new node, and the superseded permutes are cleaned up by dead-code
    elimination in ``call``.

    A permute whose output feeds more than one permute is not supported and
    raises ``NotImplementedError``.
    """

    def __init__(self):
        super().__init__()
        # Targets treated as fusable transpose / permute ops.
        self.op_map = {
            exir_ops.edge.aten.permute_copy.default,
        }
        # Permute nodes already assigned to a chain. Graph nodes are visited in
        # order, so this stops a chain from being re-collected from one of its
        # interior nodes.
        self.visited = set()
        # Chain currently being collected, in data-flow order.
        self.nodes = []

    def _traverse(self, node):
        """
        Collect the permute chain starting at ``node`` into ``self.nodes``.

        Walks forward through permute users, appending each permute to
        ``self.nodes`` and marking it visited. Stops when it reaches a node
        that is not a permute, a node that was already visited, or a permute
        with no permute user.

        Raises:
            NotImplementedError: if a permute in the chain has more than one
                permute user.
        """
        while node not in self.visited and node.target in self.op_map:
            self.nodes.append(node)
            self.visited.add(node)
            next_users = [n for n in node.users if n.target in self.op_map]
            if not next_users:
                return
            if len(next_users) > 1:
                raise NotImplementedError(
                    f"Check the node {node}, wich encounter mutilple permute output case"
                )
            # Follow the permute user itself. The current node may also have
            # non-permute users, which simply keep it alive after fusion.
            node = next_users[0]

    def _fuse_chain(self, graph: torch.fx.Graph) -> None:
        """
        Insert one permute equivalent to the chain in ``self.nodes``.

        The new node is inserted right after the chain's input, and all users
        of the chain's last permute are redirected to it.
        """
        input_node, output_node = self.nodes[0].args[0], self.nodes[-1]

        # Compose the orders: applying permute(p) to a tensor whose dims are
        # currently ``axis_order`` yields dims ``[axis_order[i] for i in p]``.
        axis_order = list(range(len(input_node.meta["val"].shape)))
        for node in self.nodes:
            axis_order = [axis_order[i] for i in node.args[1]]

        with graph.inserting_after(input_node):
            permute_node = graph.create_node(
                "call_function",
                exir_ops.edge.aten.permute_copy.default,
                (input_node, axis_order),
            )
        for user in output_node.users.copy():
            user.replace_input_with(output_node, permute_node)

        # The fused permute produces exactly what the chain's last permute
        # produced, so it inherits that node's metadata. Copy the dict rather
        # than aliasing it, so the flag set below does not also mutate
        # output_node.
        permute_node.meta = output_node.meta.copy()
        # Without "qnn_permute", we might obtain wrong input shape.
        # NOTE(review): the flag is set unconditionally, matching the
        # effective prior behavior. The earlier guard tested a non-empty list
        # (always true). If the intent is "only when some fused permute was
        # inserted", use any(pn.meta.get(QCOM_INSERTED_PERMUTE) ...) after
        # confirming downstream consumers.
        permute_node.meta[QCOM_INSERTED_PERMUTE] = True

    def _fuse(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """
        Replace every permute chain of length >= 2 with a single permute.

        The superseded permutes are left in the graph; ``call`` removes them
        afterwards via dead-code elimination.

        Returns:
            The same ``graph_module``, modified in place.
        """
        graph = graph_module.graph
        # Reset per-run state so a reused pass instance does not carry stale
        # nodes over from a previous graph.
        self.visited = set()
        self.nodes = []
        for n in graph.nodes:
            self._traverse(n)
            if len(self.nodes) > 1:
                self._fuse_chain(graph)
            # clear current stack
            self.nodes = []
        return graph_module

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Fuse permute chains, recompile, and run dead-code elimination.

        Dead-code elimination cleans up the permutes that no longer have users
        after fusion.
        """
        self._fuse(graph_module)
        graph_module.recompile()
        dead_code_elimination_pass(graph_module)
        return PassResult(graph_module, True)
82