# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import operator
from collections import defaultdict
from functools import lru_cache
from typing import Dict, Iterable, List, Optional, Set, Tuple, Union

import torch
from executorch.exir.backend.backend_details import ExportedProgram
from executorch.exir.backend.canonical_partitioners.duplicate_constant_node_pass import (
    duplicate_constant_node,
)
from executorch.exir.common import setting_python_recursive_limit
from executorch.exir.delegate import executorch_call_delegate
from executorch.exir.dialects._ops import ops as exir_ops

from executorch.exir.lowered_backend_module import create_submodule_from_nodes
from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param
from torch.fx.node import Node
from torch.fx.passes.utils.source_matcher_utils import SourcePartition

T_QuantPerTensor = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
T_DQuantPerTensor = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default


# NB: Set this to None to handle validation from MobileBert
@lru_cache(maxsize=None)
def is_same_node(
    node_left: Iterable[torch.fx.Node],
    node_right: Iterable[torch.fx.Node],
) -> bool:
    # Two nodes are the same if they have the same target and op,
    # and the same holds recursively for their input nodes.
    if isinstance(node_left, torch.fx.Node) and isinstance(node_right, torch.fx.Node):
        if not (
            (node_left.target == node_right.target)
            and (node_left.op == node_right.op)
            and (len(node_left.all_input_nodes) == len(node_right.all_input_nodes))
            and all(
                is_same_node(arg_left, arg_right)
                for arg_left, arg_right in zip(
                    node_left.all_input_nodes, node_right.all_input_nodes
                )
            )
        ):
            return False
    else:
        if len(list(node_left)) != len(list(node_right)):
            return False
        for n_left, n_right in zip(node_left, node_right):
            if not is_same_node(n_left, n_right):
                return False
    return True


def is_identical_graph(
    graph_left: torch.fx.GraphModule, graph_right: torch.fx.GraphModule
) -> bool:
    # Two graphs are the same if they have the same nodes and ops. This check is
    # stricter than topological equivalence: two graphs with the same nodes are not
    # considered identical if those nodes appear in a different order.
    if len(list(graph_left.graph.nodes)) != len(list(graph_right.graph.nodes)):
        return False
    with setting_python_recursive_limit(30000):
        for node_left, node_right in zip(
            graph_left.graph.nodes, graph_right.graph.nodes
        ):
            if not (is_same_node(node_left, node_right)):
                return False
    return True


def remove_first_quant_and_last_dequant(
    graph_module: torch.fx.GraphModule,
) -> None:
    for node in graph_module.graph.nodes:
        if node.target == T_QuantPerTensor:
            if node.args[0].op == "placeholder":
                node_users = list(node.users.keys())
                for dequant_node in node_users:
                    # point the dequant arg to the placeholder
                    dequant_node.args = (node.args[0],) + dequant_node.args[1:]
        elif node.target == T_DQuantPerTensor:
            node_users = list(node.users.keys())
            if node_users[0].op == "output":
                # point the output arg to the quant node
                output_node = node_users[0]
                output_node.args = ([node.args[0]],)
    # Remove the quant/dequant nodes as they don't have users
    graph_module.graph.eliminate_dead_code()
    graph_module.recompile()


# TODO - use edge ops
def replace_quantized_partition_with_op(
    graph_module: torch.fx.GraphModule,
    partition: SourcePartition,
    replacement_op: torch._ops.OpOverloadPacket,
) -> Tuple[torch.fx.Node, List[torch.fx.Node], List[torch.fx.Node]]:
    """
    Replaces the partition with the op specified by replacement_op. The nodes contained
    in the partition are expected to be sourced from a quantized module, as this function
    searches for the surrounding quantization pattern and consumes it along with the
    nodes in the partition, all of which are then replaced by replacement_op.

    Args:
        graph_module: The graph module from which this partition was sourced.
        partition: Partition to be replaced.
        replacement_op: The op to replace the partition with.
    Returns:
        Tuple: The first element in the tuple is the new replaced module. The second and
        third node lists in the returned tuple consist of the dq and q nodes that were
        consumed along with this partition to be replaced by the replacement_op.
    """

    dequant_nodes = []
    quant_nodes = []
    input_nodes = []
    output_nodes = []

    partition_nodes = [node for node in partition.nodes if node not in partition.params]

    # We recreate our input nodes and output nodes list instead of using partition.input_nodes
    # and partition.output_nodes as the ordering of the nodes in those lists is not deterministic,
    # whereas for the quant fusion pass we expect deterministic ordering.
    for node in partition.nodes:
        for arg in node.args:
            if isinstance(arg, torch.fx.Node) and (arg not in partition.nodes):
                input_nodes.append(arg)

        for user in node.users.keys():
            if user not in partition.nodes:
                output_nodes.append(node)

    # Try to find all the dq nodes that are feeding into this module.
    for node in input_nodes:
        if node.target == T_DQuantPerTensor:
            dequant_nodes += [node]

    # Try to find all the q nodes that this module is feeding out into.
    for node in output_nodes:
        for user in node.users.keys():
            if user.target == T_QuantPerTensor:
                quant_nodes += [user]

    assert len(dequant_nodes) >= 1, "Dequant nodes missing in node list to be replaced."
    assert len(quant_nodes) >= 1, "Quant nodes missing in node list to be replaced."

    # After this, node list will essentially contain all the nodes in the
    # dq->op->q pattern that we will want to replace with a custom backend op.
    node_list = dequant_nodes + partition_nodes + quant_nodes

    submodule, call_module_node = create_submodule_from_nodes(
        graph_module, node_list, "to_be_replaced", skip_legalize_graph=True
    )

    # Update the replaced op so that we have all the latest args and kwargs.
    with graph_module.graph.inserting_before(call_module_node):
        replaced_op = graph_module.graph.call_function(
            replacement_op,
            call_module_node.args,
            kwargs=call_module_node.kwargs,
        )
        call_module_node.replace_all_uses_with(replaced_op)
        graph_module.graph.erase_node(call_module_node)
    replaced_op.meta = call_module_node.meta
    graph_module.recompile()

    return (replaced_op, dequant_nodes, quant_nodes)


def _assign_new_tag(
    tagged_exported_program: ExportedProgram,
    copied_nodes: Set[str],
):
    """
    Assign a new tag to the copied nodes.

    Before the pass:
        constant_0 (tag_10) ------------------> op_b (tag_10)
        constant_0_copy (tag_10) -------------> op_a (tag_11)

    After the pass:
        constant_0 (tag_10) ------------------> op_b (tag_10)
        constant_0_copy (tag_11) -------------> op_a (tag_11)
    """
    for node in tagged_exported_program.graph.nodes:
        if node.op == "placeholder":
            if node.name in copied_nodes:
                users_tag = set()
                for user in node.users:
                    users_tag.add(user.meta.get("delegation_tag", None))
                # Give the copied constant node the same tag as its users.
                if len(users_tag) == 1:
                    node.meta["delegation_tag"] = users_tag.pop()


def _maybe_duplicate_constant_nodes(
    tagged_exported_program: ExportedProgram,
    tag: str,
) -> None:
    """
    If a constant node is shared by nodes with different tags, like

        constant_0 ----> op_b (tag_10)
            |----------> op_a (tag_11)

    the default is to duplicate constant_0 (into constant_0_1, constant_0_2, ...),
    unless the node is tagged with "no_copy":

        constant_0 ------------------> op_b (tag_10)
        constant_0_copy -------------> op_a (tag_11)

    Backends can decide how much constant duplication they want to allow, and either
    error out or fall back to duplicating the constant node.
    """
    candidate_nodes = set()
    for node in tagged_exported_program.graph.nodes:
        if node.meta.get("delegation_tag", "") == tag:
            if node.op == "placeholder":
                for user in node.users:
                    users_tag = user.meta.get("delegation_tag", None)
                    if users_tag != tag:
                        # If the node is tagged with "no_copy", we stop duplicating it and throw an error
                        if node.meta.get("no_copy", False):
                            raise RuntimeError(
                                f"constant data node ({node}) is tagged with ({tag}) but has user ({user}) which has tag ({users_tag})"
                            )
                        else:
                            candidate_nodes.add(node.name)
    copied_nodes = set()
    for candidate_node in candidate_nodes:
        # Both the tagged exported program and the owning program need to go through the same duplication pass
        copied_nodes = copied_nodes.union(
            duplicate_constant_node(tagged_exported_program, candidate_node)
        )
    candidate_node_with_copies = candidate_nodes.union(copied_nodes)
    _assign_new_tag(tagged_exported_program, candidate_node_with_copies)

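# A minimal sketch (illustrative only, not part of this module's API) of how a
# partitioner might opt a large shared constant out of duplication before
# _maybe_duplicate_constant_nodes runs. The function name and the
# `constant_node_name` parameter are hypothetical.
def _example_mark_shared_constant_no_copy(
    exported_program: ExportedProgram, constant_node_name: str
) -> None:
    for node in exported_program.graph.nodes:
        if node.op == "placeholder" and node.name == constant_node_name:
            # With "no_copy" set, _maybe_duplicate_constant_nodes raises a
            # RuntimeError instead of duplicating the constant across delegates.
            node.meta["no_copy"] = True
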

def _get_item_from_executorch_call_delegate(node: torch.fx.Node) -> bool:
    """
    Check if the node is a getitem node whose input is an executorch_call_delegate node.
    Such getitem nodes only exist to extract results from the delegate, because the
    inputs/outputs to delegates are flattened.
    """
    return (
        node.target == operator.getitem
        and len(node.args) == 2
        and node.args[0].target == executorch_call_delegate  # pyre-ignore
        and isinstance(node.args[1], int)
    )


def get_non_lowered_nodes(graph: torch.fx.Graph) -> List[torch.fx.Node]:
    """
    Returns a list of non-lowered nodes in the graph module.
    """
    return [
        node
        for node in graph.nodes
        if node.op == "call_function"
        and node.target != executorch_call_delegate
        and (not _get_item_from_executorch_call_delegate(node))
    ]


def get_delegates(graph: torch.fx.Graph) -> List[torch.fx.Node]:
    """
    Returns the list of delegate nodes from the graph.
    """
    return [
        node
        for node in graph.nodes
        if node.op == "get_attr" and node.name.startswith("lowered_module_")
    ]

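# A minimal sketch (illustrative only) of how the two helpers above might be used to
# summarize partitioning results after lowering; `edge_graph_module` is a hypothetical
# torch.fx.GraphModule produced by to_backend.
def _example_summarize_lowering(edge_graph_module: torch.fx.GraphModule) -> str:
    non_lowered = get_non_lowered_nodes(edge_graph_module.graph)
    delegates = get_delegates(edge_graph_module.graph)
    return (
        f"{len(delegates)} delegate(s), "
        f"{len(non_lowered)} non-lowered call_function node(s)"
    )
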

def print_delegated_graph(graph_module: torch.fx.GraphModule) -> None:
    """
    Print the formatted graph string.
    """
    print(format_delegated_graph(graph_module))


def format_delegated_graph(graph_module: torch.fx.GraphModule) -> str:
    """
    Return the formatted graph string of the graph module, including each
    lowered_module (both its backend id and its original graph). Example output:

    graph():
        %arg0_1 : [num_users=2] = placeholder[target=arg0_1]
        %arg1_1 : [num_users=2] = placeholder[target=arg1_1]
        %arg2_1 : [num_users=2] = placeholder[target=arg2_1]
        %lowered_module_0 : [num_users=1] = get_attr[target=lowered_module_0]
            backend_id: BackendWithCompilerDemo
            lowered graph():
                %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
                %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
                %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
                %aten_mm_default : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mm.default](args = (%arg0_1, %arg1_1), kwargs = {})
                %aten_add_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_mm_default, %arg2_1), kwargs = {})
                return [aten_add_tensor]
        %executorch_call_delegate : [num_users=1] = call_function[target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_0, %arg0_1, %arg1_1, %arg2_1), kwargs = {})
        %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate, 0), kwargs = {})
        %aten_sub_tensor : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.sub.Tensor](args = (%getitem, %arg0_1), kwargs = {})
        %lowered_module_1 : [num_users=1] = get_attr[target=lowered_module_1]
            backend_id: BackendWithCompilerDemo
            lowered graph():
                %aten_sub_tensor : [num_users=1] = placeholder[target=aten_sub_tensor]
                %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
                %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
                %aten_mm_default_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mm.default](args = (%aten_sub_tensor, %arg1_1), kwargs = {})
                %aten_add_tensor_1 : [num_users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.add.Tensor](args = (%aten_mm_default_1, %arg2_1), kwargs = {})
                return [aten_add_tensor_1]
        %executorch_call_delegate_1 : [num_users=1] = call_function[target=torch.ops.higher_order.executorch_call_delegate](args = (%lowered_module_1, %aten_sub_tensor, %arg1_1, %arg2_1), kwargs = {})
        %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%executorch_call_delegate_1, 0), kwargs = {})
        return [getitem_1]
    """
    lowered_module_dict = {
        node.name: getattr(graph_module, node.name)
        for node in graph_module.graph.nodes
        if node.op == "get_attr" and node.name.startswith("lowered_module_")
    }
    indent = "  "
    graph_format_str = "graph():\n"
    for node in graph_module.graph.nodes:
        graph_format_str += f"{indent}{node.format_node()}\n"
        if node.op == "get_attr" and node.name.startswith("lowered_module_"):
            lowered_module = lowered_module_dict[node.name]
            graph_format_str += f"{indent * 2}backend_id: {lowered_module.backend_id}\n"
            graph_format_str += f"{indent * 2}lowered graph():\n"
            for node_in_lowered_module in lowered_module.original_module.graph.nodes:
                graph_format_str += (
                    f"{indent * 3}{node_in_lowered_module.format_node()}\n"
                )
    return graph_format_str


def tag_constant_data(edge_program: ExportedProgram) -> None:
    """
    Util function for partitioners. This function tags the const/param/buffer nodes
    whose users all belong to the same partition. It should be called after all other
    nodes have been tagged. Any const/param/buffer that is used as input to a subgraph
    is tagged with the same tag as that subgraph. An error is raised when a
    const/param/buffer is used across different partitions, i.e. when the underlying
    data would be owned by multiple delegates.
    """
    mutated_buffer = set()
    for node in edge_program.graph.nodes:
        if node.op == "placeholder" and (
            is_param(edge_program, node)
            or is_buffer(edge_program, node)
            or is_lifted_tensor_constant(edge_program, node)
        ):
            for node_user in node.users:
                if node_user.name in edge_program.graph_signature.buffers_to_mutate:
                    logging.info(
                        "The buffer node is a mutated buffer node, which is not constant."
                    )
                    mutated_buffer.add(node)

    for node in edge_program.graph.nodes:
        # go through const/param/buffer nodes; if all users of a node are partitioned, partition it too
        if node.op == "placeholder" and (
            is_param(edge_program, node)
            or is_buffer(edge_program, node)
            or is_lifted_tensor_constant(edge_program, node)
        ):
            if node not in mutated_buffer:
                user_tags = set()
                for user in node.users:
                    user_tag = user.meta.get("delegation_tag", None)
                    if user_tag is not None:
                        user_tags.add(user_tag)
                if len(user_tags) > 1:
                    logging.info(
                        f"The data node is used across multiple partitions, including {user_tags}. "
                        "If the data is too large and copying is not preferred, tag the "
                        'constant node with node.meta["no_copy"] = True and it will not be copied.'
                    )
                # tag the data node with the same tag as the last user
                if len(user_tags) > 0:
                    node.meta["delegation_tag"] = user_tags.pop()

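# A minimal sketch (illustrative only) of the intended call order in a partitioner:
# tag compute nodes first, then let tag_constant_data propagate those tags to the
# const/param/buffer placeholders they use. The function name and the blanket
# call_function filter are hypothetical; a real partitioner selects nodes per backend.
def _example_partition_and_tag(edge_program: ExportedProgram, tag: str) -> None:
    for node in edge_program.graph.nodes:
        if node.op == "call_function":  # hypothetical: tag every compute node
            node.meta["delegation_tag"] = tag
    # Propagate the tags to the constant data used by the tagged nodes.
    tag_constant_data(edge_program)
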

def tag_mutated_buffer(edge_program: ExportedProgram) -> None:
    """
    Util function for partitioners. This function tags the mutated buffer nodes
    whose users all belong to the same partition. It should be called after all other
    nodes have been tagged. Any buffer that is used as input to a subgraph is tagged
    with the same tag as that subgraph. An error is raised when a buffer is used
    across different partitions, i.e. when the underlying data would be owned by
    multiple delegates.
    """
    for node in edge_program.graph.nodes:
        # Determine whether this node is a mutated buffer
        is_mutated_buffer_node = False
        if node.op == "placeholder" and is_buffer(edge_program, node):
            for node_user in node.users:
                if node_user.name in edge_program.graph_signature.buffers_to_mutate:
                    is_mutated_buffer_node = True
                    break
        # This node is a mutated buffer, tag it
        if is_mutated_buffer_node:
            user_tags = set()
            for user in node.users:
                user_tag = user.meta.get("delegation_tag", None)
                if user_tag is not None:
                    user_tags.add(user_tag)
            if len(user_tags) > 1:
                logging.info(
                    f"The data node is used across multiple partitions, including {user_tags}. "
                    "If the data is too large and copying is not preferred, tag the "
                    'buffer node with node.meta["no_copy"] = True and it will not be copied.'
                )
            # tag the data node with the same tag as the last user
            if len(user_tags) > 0:
                node.meta["delegation_tag"] = user_tags.pop()


# TODO - style: use templated types
class DelegateMappingBuilder:
    """
    Profiling helper class for building delegate mappings.
    Delegate mappings are mappings from delegate debug identifiers to node
    debug handles. Specifically, this is used for logging within backend delegates.

    Args:
        generated_identifiers (bool, optional): Whether identifier keys are
            generated automatically. Defaults to False.
    """

    def __init__(self, generated_identifiers: bool = False):
        self._generated_identifiers = generated_identifiers

        # Note that the internal struct has a Set value, while the getter
        # function returns the values as a tuple
        self._debug_handle_map: Union[Dict[int, Set[int]], Dict[str, Set[int]]] = (
            defaultdict(set)
        )
        self._next_index: int = 0

    def get_delegate_mapping(
        self,
    ) -> Union[Dict[int, Tuple[int]], Dict[str, Tuple[int]]]:
        """
        Returns:
            Union[Dict[int, Tuple[int]], Dict[str, Tuple[int]]]:
                A map of delegate debug identifiers to debug handles.
                The keys (identifiers) are either integers or strings.
                The values are sorted tuples of integer debug handles.
        """
        # pyre-ignore Warning between Union[Dict[K, V], Dict[K2, V]] vs Dict[Union[K, K2], V]
        return {k: tuple(sorted(v)) for k, v in self._debug_handle_map.items()}

    def insert_delegate_mapping_entry(
        self,
        nodes: Optional[Union[Node, List[Node]]] = None,
        handles: Optional[Union[int, List[Optional[int]]]] = None,
        identifier: Optional[Union[int, str]] = None,
    ) -> Union[int, str]:
        """
        Add a new delegate mapping entry.

        If self._generated_identifiers = False:
            - A new identifier must be provided, else an exception is thrown

        If self._generated_identifiers = True:
            - New identifiers are generated incrementally, 0 indexed
            - Identifiers cannot be manually provided, else an exception is thrown

        Args:
            nodes (Union[Node, List[Node]]): A (list of) Node(s)
            handles (Union[int, List[Optional[int]]]): A (list of) debug handle(s)
            identifier (Optional[Union[int, str]]):
                Debug identifier corresponding to the Node(s)

        Note: Exactly one of nodes and handles must be provided
        Note: If a debug handle is missing or None, it is skipped

        Returns:
            Union[int, str]:
                Delegate debug identifier inserted
        """

        # Check for manual addition of an identifier (with generated identifiers enabled)
        if self._generated_identifiers and identifier is not None:
            raise Exception(
                f"Builders using generated identifiers can't manually add identifiers: {identifier}. Failed to add or update entry."
            )

        if identifier is not None and identifier in self._debug_handle_map:
            raise Exception(
                "This delegate debug identifier was already inserted. Duplicate delegate debug identifiers are not allowed."
            )

        # Check that exactly one of nodes and handles is populated
        if not ((nodes is not None) ^ (handles is not None)):
            raise Exception(
                "Exactly one of nodes or handles must be provided; either both were given or neither was. Failed to add or update entry."
            )

        # Resolve identifier
        if identifier is None:
            if self._generated_identifiers:
                identifier = self._next_index
                self._next_index += 1
            else:
                raise Exception(
                    "No identifier provided. Failed to add or update entry."
                )

        # Collect debug handles
        if nodes is not None:
            new_debug_handles = {
                node.meta.get("debug_handle")
                for node in (nodes if isinstance(nodes, List) else [nodes])
            }
        else:
            new_debug_handles = (
                handles if isinstance(handles, (tuple, List)) else [handles]
            )

        # Filter out empty debug handles
        filtered_debug_handles = {
            handle for handle in new_debug_handles if handle is not None
        }
        if len(filtered_debug_handles) == 0:
            raise Exception("No valid debug handles found. Failed to add entry.")

        # pyre-ignore Warning from Union[int, str] keys
        self._debug_handle_map[identifier] = filtered_debug_handles
        return identifier


class WhyNoPartition:
    """
    Simple helper class for partitioners to log why a node was not lowered.

    Example usage:

        # In your backend partitioner file(s)
        why = WhyNoPartition(logger=your_backend_logger)

        # hypothetical function that checks if a node can be lowered
        if not can_be_lowered(node):
            why(node, "This node was not lowered because ...")
    """

    def __init__(self, logger: logging.Logger):
        self.logger = logger
        self.node: Optional[torch.fx.Node] = None
        self.reason: str = ""

    def __call__(self, node: torch.fx.Node, reason: str) -> None:
        self.node = node
        self.reason = reason
        self.logger.debug(self)

    def __str__(self) -> str:
        return f"WhyNoPartition: Node {self.node} was not partitioned because {self.reason}."
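
# A minimal sketch (illustrative only) of DelegateMappingBuilder usage inside a
# backend's preprocess step; `fused_nodes` is a hypothetical list of fx nodes that
# already carry "debug_handle" metadata from earlier compiler passes.
def _example_record_delegate_mapping(fused_nodes: List[Node]) -> None:
    builder = DelegateMappingBuilder(generated_identifiers=True)
    # Each call returns the generated identifier (0, 1, 2, ...) mapped to the
    # debug handles of the nodes fused into one delegate op.
    identifier = builder.insert_delegate_mapping_entry(nodes=fused_nodes)
    mapping = builder.get_delegate_mapping()
    logging.debug("delegate %s -> debug handles %s", identifier, mapping[identifier])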