xref: /aosp_15_r20/external/executorch/backends/qualcomm/_passes/remove_redundancy.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6import torch
7from executorch.exir.dialects._ops import ops as exir_ops
8from executorch.exir.pass_base import ExportPass, PassResult
9from executorch.exir.passes import dead_code_elimination_pass
10
11
class RemoveRedundancy(ExportPass):
    """
    Trim 'identity' operators to reduce unnecessary copy overhead.

    Every node whose call target is listed in ``redundant_ops`` (clone, alias
    and lift_fresh_copy variants) is bypassed: all of its users are rewired to
    consume the node's first positional argument directly, and the node is
    then erased from the graph.
    """

    # Call targets treated as identity operators by this pass. The rewrite
    # forwards ``args[0]`` unconditionally, so any keyword arguments on these
    # nodes (e.g. a clone's memory_format) are discarded with the node.
    redundant_ops = {
        torch.clone,
        torch.ops.aten.clone.default,
        exir_ops.edge.aten.clone.default,
        torch.ops.aten.alias.default,
        exir_ops.edge.aten.alias.default,
        exir_ops.edge.aten.lift_fresh_copy.default,
    }

    def __init__(self):
        super().__init__()

    def _remove(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
        """
        Bypass and erase every redundant node in ``graph_module``'s graph.

        Args:
            graph_module: Module whose graph is rewritten in place.

        Returns:
            The same ``graph_module`` instance, after mutation.
        """
        # Iterate over a snapshot: nodes are erased inside the loop, so do
        # not depend on the live node list staying consistent while mutated.
        for node in list(graph_module.graph.nodes):
            if node.target not in self.redundant_ops:
                continue
            # Point every consumer (including the graph output) at the
            # identity op's input; the node is then user-less, which is what
            # erase_node requires.
            node.replace_all_uses_with(node.args[0])
            graph_module.graph.erase_node(node)
        return graph_module

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Run the pass on ``graph_module``.

        Removes redundant identity nodes, recompiles the module so its
        generated code reflects the edited graph, then runs dead-code
        elimination.

        Args:
            graph_module: Module to transform in place.

        Returns:
            ``PassResult`` wrapping the same module; ``modified`` is always
            reported as ``True``.
        """
        self._remove(graph_module)
        graph_module.recompile()
        dead_code_elimination_pass(graph_module)
        return PassResult(graph_module, True)
44