xref: /aosp_15_r20/external/executorch/backends/qualcomm/_passes/expand_broadcast_tensor_shape.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import torch
8from executorch.exir.dialects._ops import ops as exir_ops
9from executorch.exir.pass_base import ExportPass, PassResult
10from executorch.exir.passes import dead_code_elimination_pass
11
12
class ExpandBroadcastTensorShape(ExportPass):
    """
    Make tensors have same rank for layout-transform to work properly.

    For each add/sub/mul/div ``.Tensor`` node, every tensor operand whose rank
    is lower than the rank of the node's output is routed through an inserted
    ``view_copy`` that left-pads its shape with 1s up to the output rank.
    """

    def __init__(self):
        super(ExpandBroadcastTensorShape, self).__init__()
        # Ops whose operands are inspected for a rank mismatch with the output.
        self.broadcast_op_targets = [
            exir_ops.edge.aten.add.Tensor,
            exir_ops.edge.aten.sub.Tensor,
            exir_ops.edge.aten.mul.Tensor,
            exir_ops.edge.aten.div.Tensor,
        ]

    def traverse_broadcast_node(self, graph_module: torch.fx.GraphModule):
        """
        Insert rank-expanding ``view_copy`` nodes in front of broadcast ops.

        Only the broadcasting node itself is rewired to the reshaped tensor;
        other consumers of the same operand keep seeing the original tensor.

        Args:
            graph_module: module whose graph is mutated in place. Nodes are
                expected to carry a ``meta["val"]`` entry exposing ``.shape``.
        """
        graph = graph_module.graph
        # (operand, padded shape) -> view node, so that several broadcast ops
        # sharing one operand reuse a single view_copy instead of duplicating it.
        expanded = {}
        for node in graph.nodes:
            if node.target not in self.broadcast_op_targets:
                continue
            output_rank = len(node.meta["val"].shape)
            for arg in node.args:
                # Non-Node operands (e.g. Python scalars) have no meta and
                # need no rank expansion.
                if not isinstance(arg, torch.fx.Node):
                    continue
                input_shape = list(arg.meta["val"].shape)
                input_rank = len(input_shape)
                # Only pad operands that are strictly lower-rank than the output.
                if input_rank >= output_rank:
                    continue
                new_shape = tuple([1] * (output_rank - input_rank) + input_shape)
                cache_key = (arg, new_shape)
                reshape_node = expanded.get(cache_key)
                if reshape_node is None:
                    # Placed right after the operand, hence before all of its
                    # users, which keeps the graph topologically valid.
                    with graph.inserting_after(arg):
                        reshape_node = graph.create_node(
                            "call_function",
                            exir_ops.edge.aten.view_copy.default,
                            (arg, new_shape),
                        )
                    # Copy meta entry by entry into the new node's own dict so
                    # that overriding "val" below does not touch arg.meta.
                    for k, v in arg.meta.items():
                        reshape_node.meta[k] = v
                    reshape_node.meta["val"] = reshape_node.meta["val"].reshape(
                        new_shape
                    )
                    expanded[cache_key] = reshape_node
                # Rewire just this broadcast op; replacing the operand in every
                # user would change the input rank of unrelated consumers.
                node.replace_input_with(arg, reshape_node)

    def call(self, graph_module: torch.fx.GraphModule):
        """
        Run the pass on ``graph_module``.

        Expands broadcast operands, recompiles the module, runs dead code
        elimination, and returns a ``PassResult`` that always reports the
        graph as modified.
        """
        self.traverse_broadcast_node(graph_module)
        graph_module.recompile()
        dead_code_elimination_pass(graph_module)
        return PassResult(graph_module, True)
59