# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

from executorch.exir.sym_util import eval_shape, eval_shape_upper_bound


_int64_max_dim_val = torch.iinfo(torch.int64).max - 1


def get_shape(input_node: torch.fx.Node):
    """
    If the shape is symbolic, evaluate it; otherwise, if it has an upper-bound
    shape, return the upper-bound shape.
    Note that we must check the upper bound because, by default, the upper bound
    is int64_max.
    """
    input_val = input_node.meta["val"]
    upper_bound_shape = eval_shape_upper_bound(input_val.shape)
    for i in range(len(input_val.shape)):
        # Unbounded dims get int64 max values assigned to them.
        # This is a workaround for when export with dynamic shapes does not use
        # the constraint API but instead just traces the model with tensors of
        # the max size.
        if upper_bound_shape[i] >= _int64_max_dim_val:
            return eval_shape(input_val.shape)
    return upper_bound_shape


def get_dqlinear_input(node: torch.fx.Node):
    ops = exir_ops.edge
    node_to_backtrack = node
    # Starting from the activation input, trace backwards through view_copy
    # nodes until a dequantize node is found. If any node encountered while
    # backtracking is not a view_copy, stop and return None.
    while node_to_backtrack.op != "placeholder":
        if (
            node_to_backtrack.op == "call_function"
            and node_to_backtrack.target
            == ops.quantized_decomposed.dequantize_per_tensor.tensor
        ):
            return node_to_backtrack
        if (
            node_to_backtrack.op == "call_function"
            and node_to_backtrack.target == ops.aten.view_copy.default
        ):
            node_to_backtrack = node_to_backtrack.args[0]
        else:
            return None
    return None


def replace_linear_view_copy_input_output(graph: torch.fx.Graph) -> torch.fx.Graph:
    """
    Replaces the pattern: x -> view_copy -> view_copy -> linear -> view_copy -> y
    with
    x -> linear -> y
    Linear nodes can handle input tensors with > 2 dimensions.
    """
    ops = exir_ops.edge
    for node in graph.nodes:
        if node.op == "call_function" and (node.target == ops.aten.linear.default):
            input_node = node.args[0]
            dqlinear_input = get_dqlinear_input(input_node)
            if dqlinear_input is not None and dqlinear_input != input_node:
                if len(input_node.args[0].users) == 1:
                    input_node.replace_all_uses_with(dqlinear_input)
                else:
                    print(
                        f"{input_node} has more than one user. Users: {input_node.users}"
                    )
            if len(node.users) == 1:
                users = list(node.users)
                maybe_view_copy = users[0]
                if maybe_view_copy.op == "call_function" and (
                    maybe_view_copy.target == ops.aten.view_copy.default
                ):
                    # Must re-read the input node since the original one may have
                    # been replaced above.
                    input_node = node.args[0]
                    input_shape = list(get_shape(input_node))
                    weight_node = node.args[1]
                    if "val" not in weight_node.meta:
                        raise ValueError(f"Val not found in meta of node {weight_node}")
                    weight_val = weight_node.meta["val"]
                    output_channels = weight_val.shape[0]
                    output_shape = input_shape
                    output_shape[-1] = output_channels
                    view_copy_out_shape = list(get_shape(maybe_view_copy))
                    if output_shape == view_copy_out_shape:
                        maybe_view_copy.replace_all_uses_with(node)
    graph.eliminate_dead_code()
    return graph


def replace_addmm_mm_with_linear(graph: torch.fx.Graph) -> torch.fx.Graph:
    """
    Replace calls to addmm/mm with a linear node.
    This simplifies the downstream lowering logic to handling just the linear node.
    It also removes various view_copy nodes, which so far have been absorbed by the
    delegate by ignoring them entirely.
    Furthermore, removing view_copy nodes means patterns no longer have to match
    against view copies, which keeps them simpler. Simplified patterns are less
    brittle, since symbolic ints and sizes creeping into the graph were making
    them harder to match.
    """
    ops = exir_ops.edge
    for node in graph.nodes:
        if node.op == "call_function" and (
            node.target == ops.aten.mm.default or node.target == ops.aten.addmm.default
        ):
            with graph.inserting_after(node):
                if node.target == ops.aten.addmm.default:
                    weight_t_node = node.args[2]
                    if weight_t_node.target not in [
                        ops.aten.t_copy.default,
                        ops.aten.permute_copy.default,
                    ]:
                        # Skip this node as it appears to be a standalone `addmm`.
                        continue
                    weight_node = weight_t_node.args[0]
                    # addmm(bias, input, weight_t) maps to linear(input, weight, bias).
                    args = (node.args[1], weight_node, node.args[0])
                    linear_node = graph.create_node(
                        "call_function", ops.aten.linear.default, args
                    )
                    node.replace_all_uses_with(linear_node)
                    output_val = linear_node.target(  # pyre-fixme[29]
                        args[0].meta["val"], args[1].meta["val"], args[2].meta["val"]
                    )
                else:
                    weight_t_node = node.args[1]
                    if weight_t_node.target not in [
                        ops.aten.t_copy.default,
                        ops.aten.permute_copy.default,
                    ]:
                        # Skip this node as it appears to be a standalone `mm`.
                        continue
                    weight_node = weight_t_node.args[0]
                    # mm(input, weight_t) maps to linear(input, weight).
                    args = (node.args[0], weight_node)
                    linear_node = graph.create_node(
                        "call_function", ops.aten.linear.default, args
                    )
                    node.replace_all_uses_with(linear_node)
                    output_val = linear_node.target(  # pyre-fixme[29]
                        args[0].meta["val"], args[1].meta["val"]
                    )
                linear_node.meta = node.meta
                # The val carried over in this meta (and its corresponding shape)
                # belongs to the original addmm/mm node and will not be accurate
                # for the linear node, so substitute the recomputed output value.
                linear_node.meta["val"] = output_val
    graph.eliminate_dead_code()
    return graph


def apply_addmm_mm_to_linear_transform(graph: torch.fx.Graph) -> torch.fx.Graph:
    graph = replace_addmm_mm_with_linear(graph)
    graph = replace_linear_view_copy_input_output(graph)
    return graph


class AddmmToLinearTransform(ExportPass):
    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        graph_module.graph = apply_addmm_mm_to_linear_transform(graph_module.graph)
        return PassResult(graph_module, True)
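

# A minimal usage sketch (not part of the pass itself): it assumes the caller
# already has an edge-dialect torch.fx.GraphModule, e.g. obtained via
# executorch.exir.to_edge(); the variable name `edge_graph_module` below is
# hypothetical.
#
#     result = AddmmToLinearTransform().call(edge_graph_module)
#     transformed_graph_module = result.graph_module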