# mypy: allow-untyped-defs
import collections
import enum
from typing import Any, Dict, List, Optional, Set, Tuple

import torch
from torch.ao.quantization import FakeQuantizeBase, ObserverBase
from torch.ao.quantization.utils import getattr_from_fqn
from torch.fx import GraphModule
from torch.fx.graph import Graph, Node

from .mappings import get_base_name_to_sets_of_related_ops, get_unmatchable_types_map
from .ns_types import NSNodeTargetType, NSSubgraph
from .pattern_utils import (
    end_node_matches_reversed_fusion,
    get_reversed_fusions,
    get_type_a_related_to_b,
)


toq = torch.ops.quantized


def _get_output_nodes(g: Graph) -> List[Node]:
    """Return every node of ``g`` whose op is ``"output"``, in graph order."""
    return [n for n in g.nodes if n.op == "output"]


class _NSGraphMatchableSubgraphsIterator:
    """
    Iterates through the graph of gm, starting with the output nodes
    and continuing backwards.
    1. Returns matchable subgraphs, in order. A subgraph is defined by
       (start_node, end_node).
    2. Skips over non-matchable subgraphs
    """

    def __init__(
        self,
        gm: GraphModule,
        non_matchable_functions: Set[NSNodeTargetType],
        non_matchable_modules: Set[NSNodeTargetType],
        non_matchable_methods: Set[NSNodeTargetType],
    ):
        self.gm: GraphModule = gm
        self.non_matchable_functions: Set[NSNodeTargetType] = non_matchable_functions
        self.non_matchable_modules: Set[NSNodeTargetType] = non_matchable_modules
        self.non_matchable_methods: Set[NSNodeTargetType] = non_matchable_methods
        # nodes that have already been consumed (as a whole subgraph or as an
        # interior node of a fusion); they are skipped when popped again
        self.seen_nodes: Set[Node] = set()
        # nodes still to visit; popped from the end (LIFO), seeded with the
        # graph's output nodes so traversal runs from outputs towards inputs
        self.stack: List[Node] = []
        for start_node in _get_output_nodes(self.gm.graph):
            self.stack.append(start_node)

    def __iter__(self):
        return self

    def __next__(self) -> NSSubgraph:
        """
        Returns the next matchable subgraph.

        Raises:
            StopIteration: when no unvisited matchable subgraph remains.
        """
        while len(self.stack) > 0:
            cur_end_node = self.stack.pop()
            if cur_end_node in self.seen_nodes:
                continue

            # for subgraphs which are single nodes, start_node == end_node
            # for subgraphs with more than one node, start node != end_node
            cur_start_node = cur_end_node
            # Subgraphs like linear-relu have the base node as the start node.
            # Subgraphs like dequantize-linear-relu-to(torch.float16) have the
            # base node as the second node.
            # The cur_base_op_node var will move to the actual node during
            # the fusion matching later in this code block.
            cur_base_op_node = cur_end_node

            # Check for potential fusions. For now, we are greedy
            # and always skip all non-base nodes of a fusion. For example,
            # if we match linear-relu backwards, we will always skip the
            # relu node and attempt to match the linear node. This can
            # be made configurable later if needed.
            for _reverse_fusion_ops, base_op_idx in get_reversed_fusions():
                is_match = end_node_matches_reversed_fusion(
                    cur_end_node, _reverse_fusion_ops, self.gm, self.seen_nodes
                )
                if is_match:
                    # navigate to the base node
                    for rev_fusion_idx in range(len(_reverse_fusion_ops) - 1):
                        self.seen_nodes.add(cur_start_node)
                        # for now, assume that there are no other nodes
                        # which need to be added to the stack
                        cur_start_node = cur_start_node.args[0]  # type: ignore[assignment]
                        # if the base op index matches the current node, set it
                        rev_base_op_idx = len(_reverse_fusion_ops) - 2 - base_op_idx
                        if rev_fusion_idx == rev_base_op_idx:
                            cur_base_op_node = cur_start_node
                    # only the first matching fusion pattern is used
                    break

            self.seen_nodes.add(cur_start_node)
            # add args of previous nodes to stack
            for arg in cur_start_node.all_input_nodes:
                self._recursively_add_node_arg_to_stack(arg)

            # skip unmatchable nodes
            # note: this check is done on the base op node, i.e.
            # if we are matching linear-relu in reverse, this would do the matchable
            # check on the linear
            if not self._is_matchable(cur_base_op_node):
                continue

            # If an observer or a fake_quant was not matched as a part of
            # a pattern of multiple nodes, ignore it. One case where this is
            # relevant is an observer on a graph input, which was added because
            # it is necessary for the next node.
            if cur_end_node.op == "call_module" and cur_start_node is cur_end_node:
                maybe_obs = getattr_from_fqn(self.gm, cur_end_node.target)  # type: ignore[arg-type]
                if isinstance(maybe_obs, (ObserverBase, FakeQuantizeBase)):
                    continue

            return NSSubgraph(
                start_node=cur_start_node,
                end_node=cur_end_node,
                base_op_node=cur_base_op_node,
            )

        raise StopIteration

    def _recursively_add_node_arg_to_stack(self, arg: Any) -> None:
        """
        Adds all of the nodes in this arg to the stack, properly navigating
        through list, dicts and tuples.

        Only ``Node`` instances, fx ``immutable_list``/``immutable_dict`` and
        plain ``tuple`` containers are traversed; any other value is ignored.
        """
        if isinstance(arg, Node):
            self.stack.append(arg)
        elif (
            isinstance(arg, torch.fx.immutable_collections.immutable_list)
            or type(arg) is tuple
        ):
            for inner_arg in arg:
                self._recursively_add_node_arg_to_stack(inner_arg)
        elif isinstance(arg, torch.fx.immutable_collections.immutable_dict):
            for value in arg.values():
                self._recursively_add_node_arg_to_stack(value)

    def _is_matchable(self, node: Node) -> bool:
        """
        Return whether ``node`` may take part in matching.

        call_function / call_method nodes are matchable unless their target is
        in the corresponding non-matchable set; call_module nodes are matchable
        unless the module instance is an instance of any type in
        ``non_matchable_modules``. Nodes with any other op are never matchable.
        """
        if node.op == "call_function":
            return node.target not in self.non_matchable_functions
        elif node.op == "call_module":
            assert isinstance(node.target, str)
            target_mod = getattr_from_fqn(self.gm, node.target)
            return not any(
                isinstance(target_mod, t)  # type: ignore[arg-type]
                for t in self.non_matchable_modules
            )
        elif node.op == "call_method":
            return node.target not in self.non_matchable_methods
        else:
            return False


class GraphMatchingException(Exception):
    """
    Exception raised when two graphs cannot be matched.
    """


class SubgraphTypeRelationship(enum.Enum):
    # same type, known
    # example: F.linear and F.linear, or nn.Conv2d and nn.Conv2d
    EQUAL = enum.auto()
    # same type, but the type is not known to Numerical Suite
    # (user defined type, etc).
    # note: the misspelled member name is kept for backwards compatibility
    EQUAL_BUT_UKNOWN = enum.auto()
    # known, same subgraph_relationship set, but not the same type
    # example: F.linear and toq.linear
    RELATED_BUT_NOT_EQUAL = enum.auto()
    # not related
    NOT_RELATED = enum.auto()


def _get_subgraph_relationship_type(
    subgraph_a: NSSubgraph,
    subgraph_b: NSSubgraph,
    gm_a: GraphModule,
    gm_b: GraphModule,
    type_a_related_to_b: Set[Tuple[NSNodeTargetType, NSNodeTargetType]],
) -> SubgraphTypeRelationship:
    """
    Classify how the base op nodes of ``subgraph_a`` and ``subgraph_b`` relate.

    Args:
        subgraph_a: subgraph from ``gm_a``; its ``base_op_node`` is compared.
        subgraph_b: subgraph from ``gm_b``; its ``base_op_node`` is compared.
        gm_a: module owning ``subgraph_a``, used to resolve call_module targets.
        gm_b: module owning ``subgraph_b``, used to resolve call_module targets.
        type_a_related_to_b: set of (type_a, type_b) pairs known to be related.

    Returns:
        - EQUAL_BUT_UKNOWN: the pair is not in ``type_a_related_to_b`` but the
          targets (or module types) are identical.
        - EQUAL: the pair is known and the targets (or module types) are
          identical; for functions/methods, additionally both or neither
          subgraph must have nodes in front of its base op.
        - RELATED_BUT_NOT_EQUAL: the pair is known but not identical (or
          identical functions/methods whose subgraphs differ in having nodes
          in front of the base op).
        - NOT_RELATED: everything else, including base ops whose op kind is not
          call_function / call_method / call_module.

    Raises:
        AssertionError: for call_module base ops, if either subgraph has nodes
            in front of its base op (not supported yet).
    """
    node_a = subgraph_a.base_op_node
    node_b = subgraph_b.base_op_node

    # TODO(next): make this code handle matching by what is before the base op
    if node_a.op != node_b.op:
        # functions and methods may still be related to each other
        if not (
            node_a.op in ("call_function", "call_method")
            and node_b.op in ("call_function", "call_method")
        ):
            return SubgraphTypeRelationship.NOT_RELATED

    if node_a.op in ("call_function", "call_method"):
        key = (node_a.target, node_b.target)

        if key not in type_a_related_to_b:
            if node_a.target == node_b.target:
                return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN
            else:
                return SubgraphTypeRelationship.NOT_RELATED
        # after this point, we are dealing with known types

        if node_a.target == node_b.target:
            # True when nothing precedes the base op inside the subgraph
            # (e.g. linear-relu), False for patterns such as
            # dequantize-linear-relu where the base op is not the start node
            a_base_is_start = subgraph_a.base_op_node == subgraph_a.start_node
            b_base_is_start = subgraph_b.base_op_node == subgraph_b.start_node
            if a_base_is_start != b_base_is_start:
                return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
            # TODO(future PR): check for matches start_op_node and base_op_node
            return SubgraphTypeRelationship.EQUAL

        # key is known to be in type_a_related_to_b at this point
        return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
    elif node_a.op == "call_module":
        assert (
            subgraph_a.base_op_node == subgraph_a.start_node
            and subgraph_b.base_op_node == subgraph_b.start_node
        ), "Matching call_module patterns where base_op_node != start_node is not supported yet"
        # for call_module, we need to look up the modules to do the type check
        assert isinstance(node_a.target, str)
        mod_a = getattr_from_fqn(gm_a, node_a.target)
        assert isinstance(node_b.target, str)
        mod_b = getattr_from_fqn(gm_b, node_b.target)

        type_a, type_b = type(mod_a), type(mod_b)
        key = (type_a, type_b)

        if key not in type_a_related_to_b:
            if type_a is type_b:
                return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN
            else:
                return SubgraphTypeRelationship.NOT_RELATED
        elif type_a is type_b:
            return SubgraphTypeRelationship.EQUAL
        else:
            return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL

    return SubgraphTypeRelationship.NOT_RELATED


def _get_name_for_subgraph(
    subgraph_a: NSSubgraph,
    gm_a: GraphModule,
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
    existing_names: Set[str],
) -> str:
    """
    Returns a unique name for a subgraph. This name is based on two things:
    1. the name of the set containing the underlying type of the base op in the
       subgraph (i.e. 'torch.nn.functional.linear' if this is related to a linear op)
    2. the number of previous subgraphs with related underlying type of the base op

    For example, in the graph

    linear0 -> relu0 -> linear1 -> relu1

    The subgraphs are (linear0, relu0) and (linear1, relu1). If we iterate
    from the output node backwards, the name given to (linear1, relu1) will be
    `base_op_torch.nn.functional.linear_0`, and the name given to (linear0, relu0)
    will be `base_op_torch.nn.functional.linear_1`.

    Why are we not just using the node name? Answer: because of two requirements:
    A. fusions must be supported
    B. some Numeric Suite APIs can be called without having all of the models in memory

    For example, let's say we need to match nodes of

    (1) ... -> linear0 -> relu0 -> ...

    And

    (2) ... -> linear_relu0 -> ...

    Without being able to inspect them together. With the current naming scheme, if
    we iterate through both of these graphs in the same order, and assuming the rest
    of the graphs match, both of these subgraphs will get the same name without
    (1) and (2) knowing anything about each other.

    Side effect: the returned name is added to ``existing_names``. If the base
    op's type is in none of the sets, the base name part is ``None``.
    """
    target_type = _get_node_target_type(subgraph_a.base_op_node, gm_a)
    target_base_type = None
    # NOTE(review): there is no early break, so if the type appears in several
    # sets the last one in dict iteration order wins.
    for base_name, sets_of_related_ops in base_name_to_sets_of_related_ops.items():
        if target_type in sets_of_related_ops:
            target_base_type = base_name
    target_base_name = "base_op_" + str(target_base_type)
    counter = 0
    proposed_name = target_base_name + "_" + str(counter)
    while proposed_name in existing_names:
        counter += 1
        proposed_name = target_base_name + "_" + str(counter)
    existing_names.add(proposed_name)
    return proposed_name


def _get_node_target_type(node: Node, gm: GraphModule) -> Optional[NSNodeTargetType]:
    """
    Return the target "type" used for matching ``node``.

    For call_function / call_method this is ``node.target`` itself; for
    call_module it is the class of the module that ``node.target`` refers to
    in ``gm``. Returns ``None`` for any other op.
    """
    if node.op in ("call_function", "call_method"):
        return node.target
    elif node.op == "call_module":
        assert isinstance(node.target, str)
        mod = getattr_from_fqn(gm, node.target)
        return type(mod)
    return None


def get_matching_subgraph_pairs(
    gm_a: GraphModule,
    gm_b: GraphModule,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Dict[str, Tuple[NSSubgraph, NSSubgraph]]:
    """
    Matches matchable subgraphs of graph_a to graph_b.

    For a node, "matchable" is defined as a node which is not an observer,
    fake_quants, quant or dequant (precisely: whose target is not listed in
    ``unmatchable_types_map``).

    A subgraph can contain one or more nodes. A subgraph is matchable if
    at least one node inside of it is matchable. Currently, all nodes in
    a subgraph must be matchable (because we assume no observers will be
    inserted in the middle of a fusion).

    A subgraph is defined by (start_node, end_node). We assume that only
    start_node and end_node are linked with the surrounding graph, all other
    nodes in a subgraph are self-contained.

    A pair of nodes is "related" if both nodes represent the same mathematical
    operation across different quantization flavors. For example,
    `F.linear` and `torch.ops.quantized.linear` are related, and
    `F.linear` and `torch.nn.Conv` are not related.

    For each matchable pair of nodes node_a and node_b, they will match
    if node_a and node_b are related.

    For graphs A and B, they will match iff:
    1. the number of matchable subgraphs in A and B is equivalent
    2. when iterating through the matchable subgraphs of A and B in the same order, each
       corresponding pair of base nodes is related.

    Pairs whose base ops have the same type, but a type unknown to Numeric
    Suite, are skipped and do not appear in the result.

    This enables us to find the corresponding subgraphs between
    graphs of related models. For example, if we had two graphs such as:

    graph_a: x0 -> conv_0 (type: nn.Conv2d) -> obs_0 -> x1
             w  -/
             b  -/

    graph_b: x0 -> quant_0 -> qconv_0 (type: nnq.Conv2d) -> dequant_0 -> x1
           packed_params_0 -/

    This function will return the following result (subgraphs are shown as
    (start_node, end_node)):
    {
        'base_op_<base name of conv>_0': (  # name from _get_name_for_subgraph
            (conv_0, conv_0),  # (start_node_a, end_node_a)
            (qconv_0, qconv_0),  # (start_node_b, end_node_b)
        ),
    }

    Or, if we have a fusion pattern,

    graph_a: x0 -> linear_0 -> relu_0 -> obs_0 -> x1
             w  -/
             b  -/

    graph_b: x0 -> quant_0 -> linear_relu_0 -> dequant_0 -> x1
           packed_params_0 -/

    This function will return the following result:
    {
        'base_op_<base name of linear>_0': (  # name from _get_name_for_subgraph
            (linear_0, relu_0),  # (start_node_a, end_node_a)
            (linear_relu_0, linear_relu_0),  # (start_node_b, end_node_b)
        ),
    }

    The result is ordered by execution order (inputs first).

    Raises:
        GraphMatchingException: if a pair of subgraphs is not related, or if
            one graph runs out of matchable subgraphs before the other.
        AssertionError: if the generated names of a matched pair differ.
    """
    if unmatchable_types_map is None:
        unmatchable_types_map = get_unmatchable_types_map()
    non_matchable_functions = unmatchable_types_map["funs_unmatchable"]
    non_matchable_modules = unmatchable_types_map["mods_unmatchable"]
    non_matchable_methods = unmatchable_types_map["meths_unmatchable"]

    graph_a_iterator = _NSGraphMatchableSubgraphsIterator(
        gm_a, non_matchable_functions, non_matchable_modules, non_matchable_methods
    )
    graph_b_iterator = _NSGraphMatchableSubgraphsIterator(
        gm_b, non_matchable_functions, non_matchable_modules, non_matchable_methods
    )
    results: Dict[str, Tuple[NSSubgraph, NSSubgraph]] = collections.OrderedDict()
    if base_name_to_sets_of_related_ops is None:
        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
    type_a_related_to_b = get_type_a_related_to_b(base_name_to_sets_of_related_ops)

    existing_names_a: Set[str] = set()
    existing_names_b: Set[str] = set()

    while True:
        # fetch the next subgraphs from a and b
        cur_subgraph_a, cur_subgraph_b = None, None
        try:
            cur_subgraph_a = next(graph_a_iterator)
        except StopIteration:
            pass
        try:
            cur_subgraph_b = next(graph_b_iterator)
        except StopIteration:
            pass

        # look up types of a and b for useful error messages
        type_start_a, type_start_b = None, None
        if cur_subgraph_a is not None:
            type_start_a = _get_node_target_type(cur_subgraph_a.start_node, gm_a)
        if cur_subgraph_b is not None:
            type_start_b = _get_node_target_type(cur_subgraph_b.start_node, gm_b)

        # check for results and determine what to do next
        if cur_subgraph_a is not None and cur_subgraph_b is not None:
            # both nodes were fetched, check for subgraph_relationship
            # note: subgraph_relationship is checked on the base op node, i.e.
            # if a linear-relu pattern is checked, we would check for subgraph_relationship
            # of the linear
            subgraph_relationship = _get_subgraph_relationship_type(
                cur_subgraph_a, cur_subgraph_b, gm_a, gm_b, type_a_related_to_b
            )
            if subgraph_relationship == SubgraphTypeRelationship.NOT_RELATED:
                msg = f"""
The subgraphs
({cur_subgraph_a}, {type_start_a}) and
({cur_subgraph_b}, {type_start_b})
are not related. Please ensure that the two models you pass in have the same number
of subgraphs, and each pair of subgraphs is related to each other."""
                raise GraphMatchingException(msg)
            elif subgraph_relationship == SubgraphTypeRelationship.EQUAL_BUT_UKNOWN:
                # skip matching but unknown types
                continue
            key_name_a = _get_name_for_subgraph(
                cur_subgraph_a, gm_a, base_name_to_sets_of_related_ops, existing_names_a
            )
            key_name_b = _get_name_for_subgraph(
                cur_subgraph_b, gm_b, base_name_to_sets_of_related_ops, existing_names_b
            )
            assert (
                key_name_a == key_name_b
            ), f"Subgraph names {key_name_a} and {key_name_b} do not match"
            results[key_name_a] = (cur_subgraph_a, cur_subgraph_b)
            continue
        elif cur_subgraph_a is None and cur_subgraph_b is None:
            # we reached the end of both graphs
            break
        else:
            # only one node was fetched, no match possible, throw error
            msg = f"""
Attempting to match
({cur_subgraph_a}, {type_start_a}) and
({cur_subgraph_b}, {type_start_b}),
one of which is empty. Please ensure that the two models you pass in have the same number
of subgraphs."""
            raise GraphMatchingException(msg)

    # The subgraph pairs are originally created by traversing the two graphs
    # from the outputs to the inputs. Reverse the results to return the
    # subgraphs in their order of execution.
    results = collections.OrderedDict(reversed(list(results.items())))

    return results