# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

from executorch.exir.sym_util import eval_shape, eval_shape_upper_bound


_int64_max_dim_val = torch.iinfo(torch.int64).max - 1


def get_shape(input_node: torch.fx.Node):
    """
    If the shape is symbolic, evaluate it; otherwise, if it has an upper-bound
    shape, return the upper-bound shape.
    Note that we must check the upper bound because, by default, the upper bound
    is int64_max.
    """
    input_val = input_node.meta["val"]
    upper_bound_shape = eval_shape_upper_bound(input_val.shape)
    for i in range(len(input_val.shape)):
        # Unbounded dims get int64 max values assigned to them.
        # This is just hacking around the case where an export with dynamic
        # shapes does not use the constraint API but instead just traces the
        # model with tensors of the max size.
        if upper_bound_shape[i] >= _int64_max_dim_val:
            return eval_shape(input_val.shape)
    return upper_bound_shape
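
# Worked example (an illustrative assumption, not taken from a real export): if a
# node's meta["val"] has shape (s0, 1024) and s0 was exported without an explicit
# upper bound, eval_shape_upper_bound reports roughly int64_max for that dim, so
# the sentinel check above falls back to eval_shape, which evaluates the symbolic
# sizes instead of returning the sentinel bound.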


def get_dqlinear_input(node: torch.fx.Node):
    ops = exir_ops.edge
    node_to_backtrack = node
    # First find the activation input.
    # Then trace it backwards through all view_copy nodes
    # until you find the dequant node.
    # If any of the nodes encountered during backtracking is not a view_copy,
    # return None.
    while node_to_backtrack.op != "placeholder":
        if (
            node_to_backtrack.op == "call_function"
            and node_to_backtrack.target
            == ops.quantized_decomposed.dequantize_per_tensor.tensor
        ):
            return node_to_backtrack
        if (
            node_to_backtrack.op == "call_function"
            and node_to_backtrack.target == ops.aten.view_copy.default
        ):
            node_to_backtrack = node_to_backtrack.args[0]
        else:
            return None
    return None
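
# Illustrative pattern this helper walks (a sketch, not taken from a real graph):
#
#   dequantize_per_tensor.tensor -> view_copy -> ... -> view_copy -> <linear's activation input>
#
# Backtracking from the activation input returns the dequantize node; any
# non-view_copy node encountered on the way (or reaching a placeholder without
# finding one) yields None.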


def replace_linear_view_copy_input_output(graph: torch.fx.Graph) -> torch.fx.Graph:
    """
    Replaces the pattern: x -> view_copy -> view_copy -> linear -> view_copy -> y
    with
    x -> linear -> y
    Linear nodes can handle input tensors with > 2 dimensions.
    """
    ops = exir_ops.edge
    for node in graph.nodes:
        if node.op == "call_function" and (node.target == ops.aten.linear.default):
            input_node = node.args[0]
            dqlinear_input = get_dqlinear_input(input_node)
            if dqlinear_input is not None and dqlinear_input != input_node:
                if len(input_node.args[0].users) == 1:
                    input_node.replace_all_uses_with(dqlinear_input)
                else:
                    print(
                        f"{input_node} has more than one user. Users: {input_node.users}"
                    )
            if len(node.users) == 1:
                users = list(node.users)
                maybe_view_copy = users[0]
                if maybe_view_copy.op == "call_function" and (
                    maybe_view_copy.target == ops.aten.view_copy.default
                ):
                    # Must update the input node since we may have replaced the
                    # original input node above.
                    input_node = node.args[0]
                    input_shape = list(get_shape(input_node))
                    weight_node = node.args[1]
                    if "val" not in weight_node.meta:
                        raise ValueError(f"Val not found in meta of node {weight_node}")
                    weight_val = weight_node.meta["val"]
                    output_channels = weight_val.shape[0]
                    output_shape = input_shape
                    output_shape[-1] = output_channels
                    view_copy_out_shape = list(get_shape(maybe_view_copy))
                    if output_shape == view_copy_out_shape:
                        maybe_view_copy.replace_all_uses_with(node)
    graph.eliminate_dead_code()
    return graph
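
# Worked example of the shape check above (hypothetical sizes, for illustration
# only): with an activation of shape (2, 7, 16) and a weight of shape (8, 16),
# the expected linear output shape is (2, 7, 8). The trailing view_copy is
# dropped only when it reshapes to exactly that shape, i.e. when it is a no-op
# on the linear output.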


def replace_addmm_mm_with_linear(graph: torch.fx.Graph) -> torch.fx.Graph:
    """
    Replace calls to addmm/mm with a linear node.
    The reason is that it simplifies the downstream logic of lowering to just the
    linear node.
    It also removes various view_copy nodes; these nodes were previously absorbed
    by the delegate by ignoring them entirely.
    Furthermore, removing view_copy nodes has the advantage of not having to match
    against view copies, which simplifies the pattern that has to be matched.
    Simplified patterns are less brittle, since symbolic ints and sizes creeping
    into the graph were making them harder to match.
    """
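    # The intended rewrite, in ATen terms (a sketch):
    #   addmm(bias, x, t_copy(weight)) -> linear(x, weight, bias)
    #   mm(x, t_copy(weight))          -> linear(x, weight)
    # which is why the addmm args are reordered as (args[1], weight, args[0]) below.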
    ops = exir_ops.edge
    for node in graph.nodes:
        if node.op == "call_function" and (
            node.target == ops.aten.mm.default or node.target == ops.aten.addmm.default
        ):
            with graph.inserting_after(node):
                if node.target == ops.aten.addmm.default:
                    weight_t_node = node.args[2]
                    if weight_t_node.target not in [
                        ops.aten.t_copy.default,
                        ops.aten.permute_copy.default,
                    ]:
                        # Skip this node as it appears to be a standalone `addmm`
                        continue
                    weight_node = weight_t_node.args[0]
                    args = (node.args[1], weight_node, node.args[0])
                    linear_node = graph.create_node(
                        "call_function", ops.aten.linear.default, args
                    )
                    node.replace_all_uses_with(linear_node)
                    output_val = linear_node.target(  # pyre-fixme[29]
                        args[0].meta["val"], args[1].meta["val"], args[2].meta["val"]
                    )
                else:
                    weight_t_node = node.args[1]
                    if weight_t_node.target not in [
                        ops.aten.t_copy.default,
                        ops.aten.permute_copy.default,
                    ]:
                        # Skip this node as it appears to be a standalone `mm`
                        continue
                    weight_node = weight_t_node.args[0]
                    args = (node.args[0], weight_node)
                    linear_node = graph.create_node(
                        "call_function", ops.aten.linear.default, args
                    )
                    node.replace_all_uses_with(linear_node)
                    output_val = linear_node.target(  # pyre-fixme[29]
                        args[0].meta["val"], args[1].meta["val"]
                    )
                linear_node.meta = node.meta
                # The val (and hence shape) carried over in this meta is not
                # accurate for the new node, so substitute the freshly computed
                # output value.
                linear_node.meta["val"] = output_val
    graph.eliminate_dead_code()
    return graph


def apply_addmm_mm_to_linear_transform(graph: torch.fx.Graph) -> torch.fx.Graph:
    graph = replace_addmm_mm_with_linear(graph)
    graph = replace_linear_view_copy_input_output(graph)
    return graph


class AddmmToLinearTransform(ExportPass):
    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        graph_module.graph = apply_addmm_mm_to_linear_transform(graph_module.graph)
        return PassResult(graph_module, True)
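
# Minimal usage sketch (an assumption about a typical caller; the export/to_edge
# flow below is illustrative only and not exercised by this file):
#
#   import torch
#   from executorch.exir import to_edge
#
#   class M(torch.nn.Module):
#       def __init__(self):
#           super().__init__()
#           self.fc = torch.nn.Linear(16, 8)
#
#       def forward(self, x):
#           return self.fc(x)
#
#   ep = torch.export.export(M(), (torch.randn(2, 16),))
#   gm = to_edge(ep).exported_program().graph_module
#   result = AddmmToLinearTransform().call(gm)
#   # result.graph_module now uses aten.linear in place of addmm/mm (+ view_copy).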