# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch
from executorch.backends.qualcomm.utils.constants import QCOM_INSERTED_PERMUTE

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import dead_code_elimination_pass


class FuseConsecutiveTranspose(ExportPass):
    """
    This pass fuses consecutive transpose / permute into one to reduce runtime
    overhead.

    A chain ``x -> permute(p0) -> permute(p1) -> ...`` is replaced by a single
    ``permute_copy`` applied to ``x`` whose dims are the composition of every
    permutation in the chain. The original permute nodes are left in the graph
    and are removed by dead code elimination once they have no users.
    """

    def __init__(self):
        super().__init__()
        # Targets that are treated as a permute and are eligible for fusion.
        self.op_map = {
            exir_ops.edge.aten.permute_copy.default,
        }
        # Permute nodes that already belong to a chain collected earlier.
        self.visited = set()
        # Chain currently being collected, ordered from first to last permute.
        self.nodes = []

    def _traverse(self, node):
        """
        Append ``node`` and the permutes that linearly follow it to
        ``self.nodes``.

        Nothing happens when ``node`` is not a permute or was already visited.

        Raises:
            NotImplementedError: if a permute in the chain feeds more than one
                permute user.
        """
        if node in self.visited or node.target not in self.op_map:
            return

        self.nodes.append(node)
        self.visited.add(node)
        next_users = [n for n in node.users if n.target in self.op_map]
        if not next_users:
            return

        if len(next_users) > 1:
            raise NotImplementedError(
                f"Check the node {node}, wich encounter mutilple permute output case"
            )

        # Follow the permute user itself. The first entry of `node.users` may
        # be an unrelated op when the permute has several consumers, which
        # would cut the chain short.
        self._traverse(next_users[0])

    def _fuse(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """
        Replace every chain of two or more permutes in ``graph_module`` with a
        single permute node and return the (in-place modified) module.
        """
        graph = graph_module.graph
        permute_op = exir_ops.edge.aten.permute_copy.default

        # Start from a clean state so that a reused pass instance, or a
        # previous run that raised, cannot leak nodes into this run.
        self.visited = set()
        self.nodes = []

        for n in graph.nodes:
            self._traverse(n)
            if len(self.nodes) > 1:
                input_node, output_node = self.nodes[0].args[0], self.nodes[-1]
                rank = len(input_node.meta["val"].shape)

                # Compose the permutations in execution order, starting from
                # the identity order of the chain input.
                axis_order = list(range(rank))
                for node in self.nodes:
                    axis_order = [axis_order[i] for i in node.args[1]]

                with graph.inserting_after(input_node):
                    permute_node = graph.create_node(
                        "call_function", permute_op, (input_node, axis_order)
                    )

                # Snapshot the users first: rewiring mutates `users`.
                for user in list(output_node.users):
                    user.replace_input_with(output_node, permute_node)

                # The fused node stands in for the last permute, so it takes
                # over that node's metadata. A shallow copy keeps the update
                # below from leaking into the node being replaced.
                permute_node.meta = output_node.meta.copy()
                # Without "qnn_permute", we might obtain wrong input shape.
                # Propagate the marker if any fused permute carried it (the
                # previous list-truthiness test was always true).
                if any(pn.meta.get(QCOM_INSERTED_PERMUTE) for pn in self.nodes):
                    permute_node.meta[QCOM_INSERTED_PERMUTE] = True

            # clear current stack
            self.nodes = []

        return graph_module

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Run the fusion, recompile the module and drop the now-dead permutes.

        Returns:
            PassResult wrapping ``graph_module``; the modified flag is always
            reported as True.
        """
        self._fuse(graph_module)
        graph_module.recompile()
        dead_code_elimination_pass(graph_module)
        return PassResult(graph_module, True)