# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import _operator
from typing import List, Tuple

import torch

from executorch.backends.qualcomm.builders.utils import is_parameter
from executorch.backends.qualcomm.utils.constants import (
    QCOM_AXIS_ORDER,
    QCOM_INSERTED_PERMUTE,
    QCOM_LAYOUT_CHANGE,
    QCOM_QUANT_ATTRS,
    QCOM_REQUANTIZE,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from executorch.exir.sym_util import eval_shape

from .utils import dq_ops, q_ops


class LayoutTransform(ExportPass):
    """
    The QNN delegate requires a channel-last layout format; this pass
    generates the required transformation while inserting the fewest
    possible 'permute' operators into the graph.
    """

    layout_sensitive_ops = {
        exir_ops.edge.aten.avg_pool2d.default,
        exir_ops.edge.aten.convolution.default,
        exir_ops.edge.aten.max_pool2d_with_indices.default,
        exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
        exir_ops.edge.aten.native_group_norm.default,
        exir_ops.edge.aten.pixel_shuffle.default,
        exir_ops.edge.aten.pixel_unshuffle.default,
        exir_ops.edge.aten.upsample_bilinear2d.default,
        exir_ops.edge.aten.upsample_nearest2d.default,
    }

    layout_agnostic_ops = {
        exir_ops.edge.aten.add.Tensor,
        exir_ops.edge.aten.bmm.default,
        exir_ops.edge.aten.cat.default,
        exir_ops.edge.aten.ceil.default,
        exir_ops.edge.aten.clamp.default,
        exir_ops.edge.aten.constant_pad_nd.default,
        exir_ops.edge.aten.div.Tensor,
        exir_ops.edge.aten.full.default,
        exir_ops.edge.aten.gelu.default,
        exir_ops.edge.aten.hardswish.default,
        exir_ops.edge.aten.hardsigmoid.default,
        exir_ops.edge.aten.hardtanh.default,
        exir_ops.edge.aten.leaky_relu.default,
        exir_ops.edge.aten.linear.default,
        exir_ops.edge.aten._log_softmax.default,
        exir_ops.edge.aten.mean.dim,
        exir_ops.edge.aten.mul.Tensor,
        exir_ops.edge.aten.pow.Tensor_Scalar,
        exir_ops.edge.aten.prelu.default,
        exir_ops.edge.aten.relu.default,
        exir_ops.edge.aten._softmax.default,  # TODO: find a better solution to apply "axis_order" when transforming the axis argument.
        exir_ops.edge.aten.sqrt.default,
        exir_ops.edge.aten.sub.Tensor,
        exir_ops.edge.aten.sum.dim_IntList,
        exir_ops.edge.aten.topk.default,
        exir_ops.edge.aten._to_copy.default,
        exir_ops.edge.aten.split_with_sizes.default,
        *q_ops,
        *dq_ops,
        _operator.getitem,
    }

    layout_type = {
        1: ("N", "N"),
        2: ("NC", "NC"),
        3: ("NCW", "NWC"),
        4: ("NCHW", "NHWC"),
        5: ("NCDHW", "NDHWC"),
    }

    @classmethod
    def get_axis_order(cls, size: List[int], reverse=False) -> Tuple[int]:
        old_layout, new_layout = cls.layout_type[len(size)]
        if reverse:
            old_layout, new_layout = new_layout, old_layout
        return tuple(old_layout.find(x) for x in new_layout)
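
    # Illustration only (not executed): with the layout_type mapping above, a
    # 4-D NCHW tensor maps into NHWC with permutation (0, 2, 3, 1), and back
    # again with (0, 3, 1, 2):
    #   LayoutTransform.get_axis_order([1, 3, 16, 16])                # (0, 2, 3, 1)
    #   LayoutTransform.get_axis_order([1, 3, 16, 16], reverse=True)  # (0, 3, 1, 2)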

    def __init__(
        self, edge_program: torch.export.ExportedProgram, insert_permute=False
    ):
        super(LayoutTransform, self).__init__()
        self.edge_program = edge_program
        self.insert_permute = insert_permute
        self.qdq_opset = {*q_ops, *dq_ops}
        self.transformed_tag = QCOM_AXIS_ORDER

    def mark_as_transformed(self, node: torch.fx.Node) -> None:
        if isinstance(node.meta["val"], (tuple, list)):
            getitem_node = list(node.users.keys())[0]
            if getitem_node.target.__name__ != "getitem":
                raise AssertionError(
                    "Expected node's user to be getitem, "
                    f"got {getitem_node.target.__name__}"
                )
            index = getitem_node.args[1]
            node.meta[self.transformed_tag] = self.get_axis_order(
                eval_shape(node.meta["val"][index].shape)
            )
        else:
            node.meta[self.transformed_tag] = self.get_axis_order(
                eval_shape(node.meta["val"].shape)
            )

    def is_transformed_node(self, node: torch.fx.Node) -> bool:
        if not hasattr(node, "meta"):
            return False
        return self.transformed_tag in node.meta

    def is_layout_sensitive(self, node: torch.fx.Node) -> bool:
        return node.target in self.layout_sensitive_ops

    def is_layout_agnostic(self, node: torch.fx.Node) -> bool:
        if node.target in [
            exir_ops.edge.aten.mean.dim,
            exir_ops.edge.aten.sum.dim_IntList,
        ]:
            # if the reduced dimension is not kept, there is no way to map
            # the layout transform onto the output
            if len(node.args) < 3 or not node.args[2]:
                return False
        if node.target in self.qdq_opset:
            return QCOM_REQUANTIZE in node.meta
        return node.target in self.layout_agnostic_ops
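
    # Illustration only (not executed): reductions stay layout-agnostic only
    # when they keep the reduced dimension. For an NCHW input x:
    #   x.mean(dim=[2, 3], keepdim=True)  # rank preserved, axis order still applies
    #   x.mean(dim=[2, 3])                # rank drops to 2, mapping is undefined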

    def is_edge_condition(self, node):
        if not isinstance(node, torch.fx.Node):
            return True

        if any(
            [
                self.is_transformed_node(node),
                node.op == "get_attr",
                (
                    node.target == exir_ops.edge.aten.permute_copy.default
                    and node.meta.get(QCOM_INSERTED_PERMUTE, False)
                ),
                (
                    node.op != "output"
                    and not isinstance(node.meta["val"], (tuple, list))
                    and len(node.meta["val"].shape) == 0
                ),
                is_parameter(node, self.edge_program),
            ]
        ):
            return True

        return False

    def insert_node(self, graph_module, node, revert_layout: bool) -> None:
        if not self.insert_permute:
            return
        with graph_module.graph.inserting_after(node):
            users = node.users.copy()
            if isinstance(node.meta["val"], tuple):
                getitem_node = list(node.users.keys())[0]
                if getitem_node.target.__name__ != "getitem":
                    raise AssertionError(
                        f"Expected bn node's user to be getitem, got {getitem_node.target.__name__}"
                    )
                index = getitem_node.args[1]
                tensor = node.meta["val"][index]
            else:
                tensor = node.meta["val"]

            permute = self.create_call_function_node(
                graph_module,
                exir_ops.edge.aten.permute_copy.default,
                (
                    node,
                    self.get_axis_order(eval_shape(tensor.shape), revert_layout),
                ),
            )
            permute.meta["val"] = tensor
            permute.meta[QCOM_QUANT_ATTRS] = node.meta.get(QCOM_QUANT_ATTRS)
            # we need this to check the annotation boundary
            permute.meta[QCOM_INSERTED_PERMUTE] = True

            # this handles the case where a residual connection is present:
            # e.g. consider the following graph
            #   x --> permute --> layer_norm --> permute --> conv2d --> add
            #   └-------------------------------------┙
            # the permute node should be correctly inserted as:
            #   x --> permute --> layer_norm --> permute --> qnn_permute --> conv2d --> add
            #   └--------------------------------------> qnn_permute -┙
            # i.e. insert the permute conditionally between the current node
            # and each user when there are multiple users
            is_node_transformed = self.is_transformed_node(node)
            for user in users:
                is_user_transformed = (
                    self.is_transformed_node(user) or QCOM_LAYOUT_CHANGE in user.meta
                )
                # insert the permute only when exactly one side has been transformed
                if is_node_transformed != is_user_transformed:
                    user.replace_input_with(node, permute)

    def create_call_function_node(
        self,
        graph_module: torch.fx.GraphModule,
        target: torch.fx.node.Target,
        args: Tuple[torch.fx.node.Argument, ...],
    ):
        return graph_module.graph.create_node(
            "call_function",
            target=target,
            args=args,
        )

    def traverse(self, node: torch.fx.Node, graph_module: torch.fx.GraphModule) -> None:
        for arg in node.args:
            if isinstance(arg, list):
                for arg_node in arg:
                    self.annotate_layout(arg_node, graph_module, revert_layout=False)
            else:
                self.annotate_layout(arg, graph_module, revert_layout=False)

        node_users = set(node.users.keys())
        for user in node_users:
            self.annotate_layout(user, graph_module, revert_layout=True)

    def annotate_layout(
        self, node: torch.fx.Node, graph_module: torch.fx.GraphModule, revert_layout
    ) -> None:

        if self.is_edge_condition(node):
            return
        elif self.is_layout_agnostic(node) or self.is_layout_sensitive(node):
            self.mark_as_transformed(node)
            self.traverse(node, graph_module)
        else:

            def check_arg(arg):
                if self.is_transformed_node(arg):
                    self.insert_node(graph_module, arg, revert_layout=revert_layout)

            if not revert_layout:
                self.insert_node(graph_module, node, revert_layout=revert_layout)
            else:
                for args in node.args:
                    if isinstance(args, torch.fx.immutable_collections.immutable_list):
                        for arg in args:
                            check_arg(arg)
                    else:
                        check_arg(args)

    def call(self, graph_module: torch.fx.GraphModule):
        graph = graph_module.graph
        sensitive_nodes = [
            node for node in graph.nodes if self.is_layout_sensitive(node)
        ]
        # perform a first traversal to identify the nodes subject to layout changes
        if self.insert_permute:
            self.insert_permute, self.transformed_tag = False, QCOM_LAYOUT_CHANGE
            for node in sensitive_nodes:
                if not self.is_transformed_node(node):
                    self.mark_as_transformed(node)
                    self.traverse(node, graph_module)
            self.insert_permute, self.transformed_tag = True, QCOM_AXIS_ORDER

        for node in sensitive_nodes:
            if not self.is_transformed_node(node):
                self.mark_as_transformed(node)
                self.traverse(node, graph_module)

        graph_module.recompile()
        if not self.insert_permute:
            graph_module = super().call(graph_module).graph_module
        return PassResult(graph_module, True)
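

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module): the pass follows
# the ExportPass interface, so it can be called directly on an edge program's
# graph module. `edge_program` below stands for any
# torch.export.ExportedProgram already lowered to the edge dialect; whether
# the pass is run with or without permute insertion is decided by the
# backend's pass pipeline.
#
#   graph_module = edge_program.graph_module
#   # annotate layout changes only
#   graph_module = LayoutTransform(edge_program)(graph_module).graph_module
#   # annotate and materialize the layout change with permute_copy nodes
#   graph_module = LayoutTransform(edge_program, insert_permute=True)(
#       graph_module
#   ).graph_module
# ---------------------------------------------------------------------------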