1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7
8import torch
9from executorch.exir.dialects._ops import ops as exir_ops
10from executorch.exir.dim_order_utils import get_dim_order
11from executorch.exir.pass_base import ExportPass, PassResult
12
13
14class MergeToDimPass(ExportPass):
15    """
16    This pass will insert to_dim ops to the pattern if satisfis requirement, like pattern_op.permuate_memory_format is set as True.
17    Example:
18        # Done for 1 to 1
19        before pass: x -> to_dim(channel_last) -> conv -> to_dim_(contiguous) -> to_dim(channel_last) -> conv -> to_dim_(contiguous) -> out
20        after pass: x -> to_dim(channel_last) -> conv -> conv -> to_dim_(contiguous) -> out
21
22        # Not Done for 1 to N
23        before pass: x -> to_dim(channel_last) -> conv -> to_dim_(contiguous) -> to_dim(channel_last) -> conv -> to_dim_(contiguous) -> out
24                                                                 |-------------> to_dim(channel_last) -> conv -> to_dim_(contiguous) -> out
25        after pass: x -> to_dim(channel_last) -> conv -> to_dim_(contiguous) -> to_dim(channel_last) -> conv -> to_dim_(contiguous) -> out
26                                                               |--------------> to_dim(channel_last) -> conv -> to_dim_(contiguous) -> out
27
28        # Not Done for N to 1
29        before pass: x -> to_dim(channel_last) -> conv -> to_dim_(contiguous) -> to_dim(channel_last) -> conv -> to_dim_(contiguous) -> out
30                     y -> to_dim(channel_last) -> conv -> to_dim_(contiguous) ---------|
31        after pass:  x -> to_dim(channel_last) -> conv -> conv -> to_dim_(contiguous) -> out
32                     y -> to_dim(channel_last) -> conv-----|
33
34        # Not Done for N to N
35    """
36
37    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
38        for node in graph_module.graph.nodes:
39            if node.target == exir_ops.edge.dim_order_ops._to_dim_order_copy.default:
40                # print(node, node.args, list(node.users), list(list(node.users)[0].args))
41                if len(node.users) == 1 and len(list(node.users)[0].args) == 2:
42                    args_map = {}
43                    node_kwargs = node.args[-1]
44                    node_users = list(node.users)
45
46                    in_to_dim_node_dim_order = node_kwargs["dim_order"]
47                    in_to_dim_node_dtype = node_kwargs["dtype"]
48                    out_to_dim_node = node_users[0]
49                    out_to_dim_node_kwargs = out_to_dim_node.args[-1]
50                    out_to_dim_node_dim_order = out_to_dim_node_kwargs["dim_order"]
51                    out_to_dim_node_dtype = out_to_dim_node_kwargs["dtype"]
52
53                    if (
54                        in_to_dim_node_dtype == out_to_dim_node_dtype
55                        and in_to_dim_node_dim_order
56                        == get_dim_order(torch.channels_last, 4)
57                        and out_to_dim_node_dim_order
58                        == get_dim_order(torch.contiguous_format, 4)
59                    ):
60
61                        out_to_dim_node_users = list(out_to_dim_node.users)
62                        assert len(out_to_dim_node_users) == 1
63                        out_to_dim_node_user = out_to_dim_node_users[0]
64                        args_map[out_to_dim_node] = node.args[0]
65                        out_to_dim_node_user_new_args = [
66                            args_map[out_to_dim_node] if arg in args_map else arg
67                            for arg in out_to_dim_node_user.args
68                        ]
69                        print("out_to_dim_node_user.args: ", out_to_dim_node_user.args)
70                        print(
71                            "out_to_dim_node_user_new_args: ",
72                            out_to_dim_node_user_new_args,
73                        )
74                        out_to_dim_node_user.args = tuple(out_to_dim_node_user_new_args)
75
76                        # pyre-fixme[29]: `Union[torch._tensor.Tensor,
77                        #  torch.nn.modules.module.Module]` is not a function.
78                        graph_module.erase_node(out_to_dim_node)
79                        # pyre-fixme[29]: `Union[torch._tensor.Tensor,
80                        #  torch.nn.modules.module.Module]` is not a function.
81                        graph_module.erase_node(node)
82            # TODO: Handle other merging rules, including 1->N, N->1, N->N
83        return PassResult(graph_module, True)
84