xref: /aosp_15_r20/external/executorch/backends/transforms/view_copy_to_squeeze_unsqueeze.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-strict
8
9from typing import List, Optional, Union
10
11import torch
12
13from executorch.exir.dialects._ops import ops as exir_ops
14from executorch.exir.pass_base import ExportPass, PassResult
15
16
class ViewCopyToSqueezeUnsqueezePass(ExportPass):
    """
    Replaces view_copy nodes with squeeze_copy.dims nodes if the view node reduces dims of size 1.
    Replaces view_copy nodes with unsqueeze_copy.default nodes if the view node adds a dim of size 1.

    Any view_copy node whose target shape cannot be expressed as one of those
    two rewrites is left untouched.
    """

    def __init__(self) -> None:
        super().__init__()
        # Edge-dialect overloads this pass matches on / rewrites to.
        self.view_copy_op: torch._ops.OpOverload = exir_ops.edge.aten.view_copy.default
        self.squeeze_op: torch._ops.OpOverload = exir_ops.edge.aten.squeeze_copy.dims
        self.unsqueeze_op: torch._ops.OpOverload = (
            exir_ops.edge.aten.unsqueeze_copy.default
        )

    def is_node_target(
        self, node: torch.fx.Node, target: torch._ops.OperatorBase
    ) -> bool:
        """Return True if `node` is a call_function node whose target is `target`."""
        return node.op == "call_function" and node.target == target

    def find_squeeze_dims(
        self,
        input_shape: List[int],
        view_shape: List[int],
    ) -> Optional[List[int]]:
        """
        Find the size-1 dims of `input_shape` that must be removed to obtain
        `view_shape`.

        Returns the list of input dim indices to squeeze, or None when
        `view_shape` is not `input_shape` with some size-1 dims dropped.
        """
        # squeezing must strictly reduce the rank
        if len(input_shape) <= len(view_shape):
            return None

        squeeze_dims: List[int] = []
        j = 0  # next position of view_shape still waiting for a match
        for i, dim in enumerate(input_shape):
            if j < len(view_shape) and dim == view_shape[j]:
                # dim is kept as-is by the view
                j += 1
            elif dim == 1:
                # size-1 dim with no counterpart in view_shape (this also
                # covers trailing dims once view_shape is exhausted)
                squeeze_dims.append(i)
            else:
                return None

        # Every view dim must have been matched; otherwise the view keeps
        # dims that squeezing alone cannot produce (e.g. [1, 1, 2] -> [2, 1]).
        if j != len(view_shape):
            return None
        return squeeze_dims

    def find_unsqueeze_dim(
        self,
        input_shape: List[int],
        view_shape: List[int],
    ) -> Optional[int]:
        """
        Find the position at which a single size-1 dim must be inserted into
        `input_shape` to obtain `view_shape`.

        Returns the index of the inserted dim in `view_shape`, or None when
        `view_shape` is not `input_shape` with exactly one size-1 dim added.
        """
        # unsqueeze should increase the length of input_shape by 1
        if len(view_shape) - len(input_shape) != 1:
            return None

        unsqueeze_dim: Optional[int] = None
        i = 0  # next position of input_shape still waiting for a match
        for j, dim in enumerate(view_shape):
            if i < len(input_shape) and input_shape[i] == dim:
                i += 1
            elif dim == 1 and unsqueeze_dim is None:
                # the one added dim (possibly the last dim of view_shape)
                unsqueeze_dim = j
            else:
                # a real mismatch, or a second extra dim
                return None

        # The rank differs by exactly one, so finishing the loop implies that
        # every input dim was matched and exactly one dim was inserted.
        return unsqueeze_dim

    def replace_view_copy_node(
        self,
        graph_module: torch.fx.GraphModule,
        view_node: torch.fx.Node,
        op: torch._ops.OpOverload,
        arg: Union[List[int], int],
    ) -> None:
        """
        Replace `view_node` with a new call_function node `op(input, arg)`,
        where `input` is the first argument of `view_node`.

        The new node reuses the meta of `view_node`, takes over all of its
        uses, and `view_node` is erased from the graph.
        """
        with graph_module.graph.inserting_before(view_node):
            new_node = graph_module.graph.create_node(
                "call_function",
                op,
                (view_node.args[0], arg),
            )
            new_node.meta = view_node.meta
            view_node.replace_all_uses_with(new_node)
            graph_module.graph.erase_node(view_node)

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Rewrite every eligible view_copy node of `graph_module` and return a
        PassResult whose flag tells whether any node was replaced.
        """
        modified = False
        for node in graph_module.graph.nodes:
            if not self.is_node_target(node, self.view_copy_op):
                continue

            input_node = node.args[0]
            input_shape = input_node.meta["val"].shape
            view_shape = node.args[1]

            squeeze_dims = self.find_squeeze_dims(input_shape, view_shape)
            if squeeze_dims is not None:
                self.replace_view_copy_node(
                    graph_module, node, self.squeeze_op, squeeze_dims
                )
                modified = True
                continue

            unsqueeze_dim = self.find_unsqueeze_dim(input_shape, view_shape)
            # Compare against None explicitly: dim 0 is a valid result.
            if unsqueeze_dim is not None:
                self.replace_view_copy_node(
                    graph_module, node, self.unsqueeze_op, unsqueeze_dim
                )
                modified = True

        if modified:
            graph_module.recompile()
            graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, modified)
129