# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.passes import dead_code_elimination_pass


class RemoveRedundancy(ExportPass):
    """
    Trim the 'identity' operators to reduce the unnecessary copy overhead.

    Every node whose target is listed in ``redundant_ops`` is bypassed: all of
    its users are rewired to consume the node's first positional argument
    directly, and the node is erased from the graph. Dead-code elimination is
    then run and the module is recompiled.
    """

    # Targets treated as identity operators (clone / alias / lift_fresh_copy
    # across the torch, aten and edge dialects). Any extra arguments these
    # nodes carry (e.g. a clone's ``memory_format``) are discarded when the
    # node is bypassed.
    redundant_ops = {
        torch.clone,
        torch.ops.aten.clone.default,
        exir_ops.edge.aten.clone.default,
        torch.ops.aten.alias.default,
        exir_ops.edge.aten.alias.default,
        exir_ops.edge.aten.lift_fresh_copy.default,
    }

    def __init__(self):
        super().__init__()

    def _remove(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """
        Bypass and erase every node in ``graph_module.graph`` whose target is
        in ``redundant_ops``.

        Args:
            graph_module: the module whose graph is edited in place.

        Returns:
            The same ``graph_module`` instance (edited in place), matching the
            declared return annotation.
        """
        # Iterate over a snapshot because nodes are erased inside the loop.
        for node in list(graph_module.graph.nodes):
            if node.target not in self.redundant_ops:
                continue
            # Redirect every consumer to the node's input so the node has no
            # remaining users; erase_node requires that.
            node.replace_all_uses_with(node.args[0])
            graph_module.graph.erase_node(node)
        return graph_module

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Run the pass on ``graph_module``.

        Args:
            graph_module: the module to transform in place.

        Returns:
            ``PassResult(graph_module, True)``. The module is always reported
            as modified, as in the original implementation.
        """
        self._remove(graph_module)
        dead_code_elimination_pass(graph_module)
        # Recompile last so the generated forward() reflects every graph edit,
        # including anything removed by dead-code elimination.
        graph_module.recompile()
        return PassResult(graph_module, True)