# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import _operator
from typing import List, Tuple

import torch

from executorch.backends.qualcomm.builders.utils import is_parameter
from executorch.backends.qualcomm.utils.constants import (
    QCOM_AXIS_ORDER,
    QCOM_INSERTED_PERMUTE,
    QCOM_LAYOUT_CHANGE,
    QCOM_QUANT_ATTRS,
    QCOM_REQUANTIZE,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.sym_util import eval_shape

from .utils import dq_ops, q_ops


class LayoutTransform(ExportPass):
27    """
28    QNN delegate requires channel last layout format, this pass aims to
29    help generate the correct transformation by inserting fewest ammount of
30    'permute' operators in the graph.
31    """

    layout_sensitive_ops = {
        exir_ops.edge.aten.avg_pool2d.default,
        exir_ops.edge.aten.convolution.default,
        exir_ops.edge.aten.max_pool2d_with_indices.default,
        exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
        exir_ops.edge.aten.native_group_norm.default,
        exir_ops.edge.aten.pixel_shuffle.default,
        exir_ops.edge.aten.pixel_unshuffle.default,
        exir_ops.edge.aten.upsample_bilinear2d.default,
        exir_ops.edge.aten.upsample_nearest2d.default,
    }

    layout_agnostic_ops = {
        exir_ops.edge.aten.add.Tensor,
        exir_ops.edge.aten.bmm.default,
        exir_ops.edge.aten.cat.default,
        exir_ops.edge.aten.ceil.default,
        exir_ops.edge.aten.clamp.default,
        exir_ops.edge.aten.constant_pad_nd.default,
        exir_ops.edge.aten.div.Tensor,
        exir_ops.edge.aten.full.default,
        exir_ops.edge.aten.gelu.default,
        exir_ops.edge.aten.hardswish.default,
        exir_ops.edge.aten.hardsigmoid.default,
        exir_ops.edge.aten.hardtanh.default,
        exir_ops.edge.aten.leaky_relu.default,
        exir_ops.edge.aten.linear.default,
        exir_ops.edge.aten._log_softmax.default,
        exir_ops.edge.aten.mean.dim,
        exir_ops.edge.aten.mul.Tensor,
        exir_ops.edge.aten.pow.Tensor_Scalar,
        exir_ops.edge.aten.prelu.default,
        exir_ops.edge.aten.relu.default,
        exir_ops.edge.aten._softmax.default,  # TODO: find a better way to apply "axis_order" when transforming the axis argument.
        exir_ops.edge.aten.sqrt.default,
        exir_ops.edge.aten.sub.Tensor,
        exir_ops.edge.aten.sum.dim_IntList,
        exir_ops.edge.aten.topk.default,
        exir_ops.edge.aten._to_copy.default,
        exir_ops.edge.aten.split_with_sizes.default,
        *q_ops,
        *dq_ops,
        _operator.getitem,
    }

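    # Maps tensor rank to (channel-first, channel-last) layout strings; the
    # permutation between the two is derived by get_axis_order below.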
    layout_type = {
        1: ("N", "N"),
        2: ("NC", "NC"),
        3: ("NCW", "NWC"),
        4: ("NCHW", "NHWC"),
        5: ("NCDHW", "NDHWC"),
    }

    @classmethod
    def get_axis_order(cls, size: List[int], reverse=False) -> Tuple[int]:
        # Derive the permutation from the channel-first layout to the
        # channel-last one (or back, when reverse=True) for the given rank.
        old_layout, new_layout = cls.layout_type[len(size)]
        if reverse:
            old_layout, new_layout = new_layout, old_layout
        return tuple(old_layout.find(x) for x in new_layout)
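
    # e.g. for a rank-4 shape:
    #   get_axis_order([1, 3, 16, 16]) -> (0, 2, 3, 1)                # NCHW -> NHWC
    #   get_axis_order([1, 3, 16, 16], reverse=True) -> (0, 3, 1, 2)  # NHWC -> NCHW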

    def __init__(
        self, edge_program: torch.export.ExportedProgram, insert_permute=False
    ):
        super().__init__()
        self.edge_program = edge_program
        self.insert_permute = insert_permute
        self.qdq_opset = {*q_ops, *dq_ops}
        self.transformed_tag = QCOM_AXIS_ORDER

    def mark_as_transformed(self, node: torch.fx.Node) -> None:
        if isinstance(node.meta["val"], (tuple, list)):
            # multi-output op: use the index from its getitem user to pick
            # the tensor whose shape determines the axis order
            getitem_node = list(node.users.keys())[0]
            if getitem_node.target.__name__ != "getitem":
                raise AssertionError(
                    "Expected node's user to be getitem, "
                    f"got {getitem_node.target.__name__}"
                )
            index = getitem_node.args[1]
            node.meta[self.transformed_tag] = self.get_axis_order(
                eval_shape(node.meta["val"][index].shape)
            )
        else:
            node.meta[self.transformed_tag] = self.get_axis_order(
                eval_shape(node.meta["val"].shape)
            )

    def is_transformed_node(self, node: torch.fx.Node) -> bool:
        if not hasattr(node, "meta"):
            return False
        return self.transformed_tag in node.meta

    def is_layout_sensitive(self, node: torch.fx.Node) -> bool:
        return node.target in self.layout_sensitive_ops

    def is_layout_agnostic(self, node: torch.fx.Node) -> bool:
        if node.target in [
            exir_ops.edge.aten.mean.dim,
            exir_ops.edge.aten.sum.dim_IntList,
        ]:
            # if the reduced dimension is not kept, there is no way to infer
            # the layout transform
            if len(node.args) < 3 or not node.args[2]:
                return False
        if node.target in self.qdq_opset:
            return QCOM_REQUANTIZE in node.meta
        return node.target in self.layout_agnostic_ops

    def is_edge_condition(self, node):
        # Returns True when traversal should stop at this node.
        if not isinstance(node, torch.fx.Node):
            return True

        if any(
            [
                self.is_transformed_node(node),
                node.op == "get_attr",
                # permutes inserted by this pass are already layout-correct
                (
                    node.target == exir_ops.edge.aten.permute_copy.default
                    and node.meta.get(QCOM_INSERTED_PERMUTE, False)
                ),
                # rank-0 tensors have no layout to transform
                (
                    node.op != "output"
                    and not isinstance(node.meta["val"], (tuple, list))
                    and len(node.meta["val"].shape) == 0
                ),
                is_parameter(node, self.edge_program),
            ]
        ):
            return True

        return False

    def insert_node(self, graph_module, node, revert_layout: bool) -> None:
        if not self.insert_permute:
            return
        with graph_module.graph.inserting_after(node):
            users = node.users.copy()
            if isinstance(node.meta["val"], tuple):
                getitem_node = list(node.users.keys())[0]
                if getitem_node.target.__name__ != "getitem":
                    raise AssertionError(
                        f"Expected node's user to be getitem, got {getitem_node.target.__name__}"
                    )
                index = getitem_node.args[1]
                tensor = node.meta["val"][index]
            else:
                tensor = node.meta["val"]

            permute = self.create_call_function_node(
                graph_module,
                exir_ops.edge.aten.permute_copy.default,
                (
                    node,
                    self.get_axis_order(eval_shape(tensor.shape), revert_layout),
                ),
            )
            permute.meta["val"] = tensor
            permute.meta[QCOM_QUANT_ATTRS] = node.meta.get(QCOM_QUANT_ATTRS)
            # we need this to check the annotation boundary
            permute.meta[QCOM_INSERTED_PERMUTE] = True

            # this is the case when a residual connection is present:
            # e.g. consider the following graph
            # x --> permute --> layer_norm --> permute --> conv2d --> add
            #               └-------------------------------------┙
            # the permute node should be inserted as:
            # x --> permute --> layer_norm --> permute --> qnn_permute --> conv2d --> add
            #               └--------------------------------------> qnn_permute -┙
            # i.e. when the node has multiple users, decide per user whether
            #      a permute is needed between the user and the current node
            is_node_transformed = self.is_transformed_node(node)
            for user in users:
                is_user_transformed = (
                    self.is_transformed_node(user) or QCOM_LAYOUT_CHANGE in user.meta
                )
                # insert the permute only when exactly one side is transformed
                if is_node_transformed != is_user_transformed:
                    user.replace_input_with(node, permute)

    def create_call_function_node(
        self,
        graph_module: torch.fx.GraphModule,
        target: torch.fx.node.Target,
        args: Tuple[torch.fx.node.Argument, ...],
    ):
        return graph_module.graph.create_node(
            "call_function",
            target=target,
            args=args,
        )

    def traverse(self, node: torch.fx.Node, graph_module: torch.fx.GraphModule) -> None:
        # propagate layout annotations upstream through args and downstream
        # through users
        for arg in node.args:
            if isinstance(arg, list):
                for arg_node in arg:
                    self.annotate_layout(arg_node, graph_module, revert_layout=False)
            else:
                self.annotate_layout(arg, graph_module, revert_layout=False)

        node_users = set(node.users.keys())
        for user in node_users:
            self.annotate_layout(user, graph_module, revert_layout=True)

    def annotate_layout(
        self, node: torch.fx.Node, graph_module: torch.fx.GraphModule, revert_layout
    ) -> None:
        if self.is_edge_condition(node):
            return
        elif self.is_layout_agnostic(node) or self.is_layout_sensitive(node):
            self.mark_as_transformed(node)
            self.traverse(node, graph_module)
        else:

            def check_arg(arg):
                if self.is_transformed_node(arg):
                    self.insert_node(graph_module, arg, revert_layout=revert_layout)

            if not revert_layout:
                self.insert_node(graph_module, node, revert_layout=revert_layout)
            else:
                for args in node.args:
                    if isinstance(args, torch.fx.immutable_collections.immutable_list):
                        for arg in args:
                            check_arg(arg)
                    else:
                        check_arg(args)

    def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph
        sensitive_nodes = [
            node for node in graph.nodes if self.is_layout_sensitive(node)
        ]
        # first traversal: identify the nodes subject to layout changes,
        # without inserting any permute yet
        if self.insert_permute:
            self.insert_permute, self.transformed_tag = False, QCOM_LAYOUT_CHANGE
            for node in sensitive_nodes:
                if not self.is_transformed_node(node):
                    self.mark_as_transformed(node)
                    self.traverse(node, graph_module)
            self.insert_permute, self.transformed_tag = True, QCOM_AXIS_ORDER

        # second traversal: annotate axis orders and, when insert_permute is
        # set, insert permutes at the layout boundaries
        for node in sensitive_nodes:
            if not self.is_transformed_node(node):
                self.mark_as_transformed(node)
                self.traverse(node, graph_module)

        graph_module.recompile()
        if not self.insert_permute:
            graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)

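
# A minimal usage sketch, assuming the standard ExportPass calling convention
# (a pass instance is callable on a GraphModule and returns a PassResult);
# `edge_program` here is a stand-in for a real torch.export.ExportedProgram:
#
#   layout_pass = LayoutTransform(edge_program, insert_permute=True)
#   result = layout_pass(edge_program.graph_module)
#   transformed_graph_module = result.graph_module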