# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from itertools import chain

import torch
from executorch.backends.example.example_operators.ops import module_to_annotator
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dim_order_utils import get_dim_order
from executorch.exir.pass_base import ExportPass, PassResult
from torch.ao.quantization.pt2e.graph_utils import find_sequential_partitions
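
# Note: get_dim_order maps a torch memory format to a dim order permutation,
# e.g. get_dim_order(torch.contiguous_format, 4) gives [0, 1, 2, 3] and
# get_dim_order(torch.channels_last, 4) gives [0, 2, 3, 1] for 4-D tensors.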


class PermuteMemoryFormatsPass(ExportPass):
    """
    This pass inserts to_dim ops around a pattern when the pattern requires it,
    i.e. when pattern_op.permuate_memory_format is set to True.

    Examples:
        before pass: x -> conv -> out
        after pass:  x -> to_dim(channels_last) -> conv -> to_dim(contiguous) -> out

        before pass: x -> conv -> conv -> out
        after pass:  x -> to_dim(channels_last) -> conv -> to_dim(contiguous) -> to_dim(channels_last) -> conv -> to_dim(contiguous) -> out

        before pass: x -> conv -> linear -> out
        after pass:  x -> to_dim(channels_last) -> conv -> to_dim(contiguous) -> to_dim(channels_last) -> linear -> to_dim(contiguous) -> out
    """
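
    # A minimal usage sketch (illustrative only; `edge_graph_module` is a
    # hypothetical edge-dialect torch.fx.GraphModule, not part of this module).
    # Like any ExportPass, the pass is callable and returns a PassResult:
    #
    #     pass_result = PermuteMemoryFormatsPass()(edge_graph_module)
    #     graph_module = pass_result.graph_module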

    def call(  # noqa: suppress function is too complex (13)
        self, graph_module: torch.fx.GraphModule
    ) -> PassResult:
        for pattern in list(module_to_annotator.keys()):
            pattern_op = module_to_annotator[pattern]
            if pattern_op.permuate_memory_format:
                partitions = find_sequential_partitions(
                    graph_module,
                    pattern,
                )
                for partition in partitions:
                    # Unpack the partition's outputs into a flattened list of exit nodes
                    output_nodes = [
                        node
                        for node in partition[0].output_nodes
                        if node.op != "placeholder"
                    ]
                    exit_nodes = [output_node.users for output_node in output_nodes]
                    exit_nodes = list(chain.from_iterable(exit_nodes))
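                    # e.g. for x -> conv -> out, output_nodes is [conv] and
                    # exit_nodes is [out] (the users of the pattern's outputs)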

                    # Step 1. Insert a to_dim op where the pattern is exited.
                    # For example, if the pattern is conv, x -> conv -> out becomes
                    # x -> conv -> to_dim(contiguous) -> out when permuting the memory
                    # format, and x -> conv -> conv -> out becomes
                    # x -> conv -> to_dim(contiguous) -> conv -> to_dim(contiguous) -> out.
                    for exit_node in exit_nodes:
                        with graph_module.graph.inserting_before(exit_node):
                            # Handle the case where the pattern output is also the graph
                            # output, e.g. x -> conv -> out becomes
                            # x -> conv -> to_dim(contiguous) -> out
                            if exit_node.op == "output":
                                exit_node_args = exit_node.args[0]
                                exit_to_dim_op = graph_module.graph.call_function(
                                    exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
                                    args=exit_node_args,
                                    kwargs={
                                        "dtype": torch.float64,
                                        "dim_order": get_dim_order(
                                            torch.contiguous_format, 4
                                        ),
                                    },
                                )
                                # The new to_dim op becomes the graph's return value
                                _ = graph_module.graph.output((exit_to_dim_op,))
                                # Remove the old return op.
                                graph_module.graph.erase_node(exit_node)
                            # Handle the case where the pattern output is an intermediate
                            # output, e.g. x -> conv -> conv -> out becomes
                            # x -> conv -> to_dim(contiguous) -> conv -> out
                            elif exit_node.op == "call_function":
                                exit_node_args = []
                                for exit_node_arg in exit_node.args:
                                    if (
                                        isinstance(exit_node_arg, torch.fx.Node)
                                        and exit_node_arg.op != "placeholder"
                                    ):
                                        exit_to_dim_op = graph_module.graph.call_function(
                                            exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
                                            args=(exit_node_arg,),
                                            kwargs={
                                                "dtype": torch.float64,
                                                "dim_order": get_dim_order(
                                                    torch.contiguous_format, 4
                                                ),
                                            },
                                        )
                                        exit_node_args.append(exit_to_dim_op)
                                    else:
                                        exit_node_args.append(exit_node_arg)
                                # fx node args are expected to be a tuple
                                exit_node.args = tuple(exit_node_args)

                    # Step 2. Insert a to_dim op where the pattern is entered. After step 1
                    # we already have to_dim(contiguous) on exit from the pattern; now insert
                    # to_dim(channels_last) on entry.
                    # For example, if the pattern is conv,
                    # x -> conv -> to_dim(contiguous) -> out becomes
                    # x -> to_dim(channels_last) -> conv -> to_dim(contiguous) -> out, and
                    # x -> conv -> to_dim(contiguous) -> conv -> to_dim(contiguous) -> out becomes
                    # x -> to_dim(channels_last) -> conv -> to_dim(contiguous) -> to_dim(channels_last) -> conv -> to_dim(contiguous) -> out.
                    # Create a map from each pattern input node to its new to_dim op,
                    # e.g. for x -> conv -> out, x maps to the to_dim op inserted after it
                    input_node_map = {}  # key: input_node, value: to_dim_op
                    to_dim_op_set = set()
                    for input_node in partition[0].input_nodes:
                        with graph_module.graph.inserting_after(input_node):
                            # Insert the to_dim op and record it in input_node_map
                            to_dim_op = graph_module.graph.call_function(
                                exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
                                args=(input_node,),
                                kwargs={
                                    "dtype": torch.float64,
                                    "dim_order": get_dim_order(torch.channels_last, 4),
                                },
                            )
                            input_node_map[input_node] = to_dim_op
                            to_dim_op_set.add(to_dim_op)
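                            # At this point every pattern input feeds a new
                            # to_dim(channels_last) node; the input's original
                            # users are rewired to that node below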

                    # Rewire each input node's users to consume the new to_dim op,
                    # skipping users whose args are already set
                    for input_node in partition[0].input_nodes:
                        for user in list(input_node.users):
                            # A user in to_dim_op_set is one of the freshly inserted
                            # to_dim ops, so its args already point at input_node
                            if user not in to_dim_op_set:
                                user_new_arg = [
                                    (
                                        input_node_map[user_arg]
                                        if user_arg in input_node_map
                                        else user_arg
                                    )
                                    for user_arg in user.args
                                ]
                                # Replace input_node with its to_dim op in the user's args
                                user.args = tuple(user_new_arg)

        # Ensure the graph is still valid
        graph_module.graph.lint()
        graph_module.recompile()
        return PassResult(graph_module, True)