# mypy: allow-untyped-defs
from dataclasses import dataclass, field
from collections import defaultdict
import copy
import torch
from torch.fx import (
    Node,
    Graph,
)
from torch.fx._compatibility import compatibility
from typing import Dict, List, Set, Any, Union, Tuple
import logging
import os

__all__ = ['SubgraphMatcher', 'InternalMatch']


# Set `PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs
def _init_logger():
    """Create the module logger with its own console handler.

    The level is read from the ``PYTORCH_MATCHER_LOGLEVEL`` environment variable
    (default ``WARNING``). Propagation is turned off so records are emitted only
    through the handler attached here.
    """
    logger = logging.getLogger(__name__)

    level = os.environ.get('PYTORCH_MATCHER_LOGLEVEL', 'WARNING').upper()
    logger.setLevel(level)
    console = logging.StreamHandler()
    formatter = logging.Formatter("%(filename)s > %(message)s")
    console.setFormatter(formatter)
    console.setLevel(level)
    # add the handlers to the logger
    logger.addHandler(console)
    logger.propagate = False
    return logger


logger = _init_logger()


@compatibility(is_backward_compatible=False)
@dataclass
class InternalMatch:
    # Nodes from which the match was found
    anchors: List[Node]
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node] = field(default_factory=dict)

    # nodes in target graph that are matched placeholder in pattern
    placeholder_nodes: List[Node] = field(default_factory=list)

    # nodes in matched subgraph returned by output
    returning_nodes: List[Node] = field(default_factory=list)

    # map from a string name to a node in the target graph
    # only available if the matcher is `SubgraphMatcherWithNameNodesMap`
    name_node_map: Dict[str, Node] = field(default_factory=dict)

    def __copy__(self):
        # Copy every container one level deep so that mutating the copy's maps/lists
        # never leaks into the original. `anchors` is shared, as it is only read here.
        return InternalMatch(anchors=self.anchors, nodes_map=self.nodes_map.copy(),
                             placeholder_nodes=self.placeholder_nodes.copy(),
                             returning_nodes=self.returning_nodes.copy(),
                             name_node_map=self.name_node_map.copy())


@compatibility(is_backward_compatible=False)
class SubgraphMatcher:
    def __init__(self, pattern: Graph,
                 match_output: bool = False,
                 match_placeholder: bool = False,
                 remove_overlapping_matches: bool = True,
                 ignore_literals: bool = False) -> None:
        """
        Args:
            pattern: the targeted matching pattern, represented in fx.Graph.
            match_output: If True, output node in the pattern graph will be treated as a part of the targeted pattern.
                If False, output node is ignored during match.
            match_placeholder: If True, placeholder node in the pattern graph will be treated as a part of
                the targeted pattern. If False, placeholder nodes will be used as a wildcard.
            remove_overlapping_matches: If True, in the case of overlapping matches, only the first match
                will be returned.
            ignore_literals: If True, will not check if literals are equal and
                will instead treat them as wildcards.

        Raises:
            ValueError: if ``pattern`` has no nodes.
            AssertionError: if a non-output node of ``pattern`` has no users (dead code).
        """

        self.pattern = pattern
        self.match_output = match_output
        self.match_placeholder = match_placeholder
        self.remove_overlapping_matches = remove_overlapping_matches
        self.ignore_literals = ignore_literals

        if len(pattern.nodes) == 0:
            raise ValueError("SubgraphMatcher cannot be initialized with an empty pattern")

        for node in pattern.nodes:
            if node.op != "output":
                assert len(node.users) > 0, \
                    "SubgraphMatcher cannot be initialized with an pattern with dead code"

        # TODO: assert pattern is a connected graph

        self.pattern_placeholder_nodes = [n for n in pattern.nodes if n.op == "placeholder"]
        # The last node of the pattern graph is taken to be its output node
        output_node = next(iter(reversed(pattern.nodes)))
        # nodes returned by outputs
        self.pattern_returning_nodes: List[Node] = output_node.all_input_nodes

        self.pattern_anchors: List[Node] = []
        if match_output:
            self.pattern_anchors = [output_node]
        else:
            # If a node has output_node as the ONLY user, then this node is a graph sink,
            # and should be matched against as an anchor
            self.pattern_anchors = [n for n in output_node.all_input_nodes if len(n.users) == 1]

    def _match_attributes(self, pn: Node, gn: Node) -> bool:
        """Compare the attributes referenced by two ``get_attr`` nodes.

        Returns False when the attribute values have different types and True when
        both are tensors (tensor values themselves are not compared).

        Raises:
            RuntimeError: if the pattern attribute is of the same type as the graph
                attribute but is not a ``torch.Tensor``.
        """
        # Attributes matching is complicated. Right now we only support matching constant tensor
        assert isinstance(pn.target, str), f"pn.target {pn.target} must be a string."
        assert isinstance(gn.target, str), f"gn.target {gn.target} must be a string."

        # TODO(tmanlaibaatar) should probably make this actual API
        def _getattr(model: torch.fx.GraphModule, attr_name: str):
            # Walk the dotted path one component at a time
            *prefix, attr = attr_name.split(".")
            t = model
            for item in prefix:
                t = getattr(t, item, None)  # type: ignore[assignment]
                assert t is not None

            return getattr(t, attr)

        pn_value = _getattr(pn.graph.owning_module, pn.target)
        gn_value = _getattr(gn.graph.owning_module, gn.target)

        if type(pn_value) != type(gn_value):
            return False

        # Don't require exact match on tensor values.
        if isinstance(pn_value, torch.Tensor):
            return isinstance(gn_value, torch.Tensor)

        raise RuntimeError(f"Unsupported type {pn_value} when matching attributes")

    def _nodes_are_equal(self, pn: Node, gn: Node) -> bool:
        """Shallow comparison of a pattern node and a graph node (op and target only)."""
        # if exact match for placeholder is not required, then use placeholder as a wildcard
        if not self.match_placeholder and pn.op == "placeholder":
            return True

        if pn.op == gn.op:
            if pn.op == "placeholder" or pn.op == "output":
                return True
            elif pn.op == "get_attr":
                return self._match_attributes(pn, gn)
            return pn.target == gn.target
        return False

    def _is_contained(self, nodes_map: Dict[Node, Node]) -> bool:
        """Return True if no matched graph node (other than placeholders and returned
        nodes) is used by a node outside the matched set."""
        # `lookup` represents all the nodes in `original_graph`
        # that are part of `pattern`

        # Placeholders can be used by other nodes in the graphs
        lookup: Dict[Node, Node] = {gn : pn for pn, gn in nodes_map.items() if pn.op != "placeholder"}

        for gn, pn in lookup.items():
            # nodes returned by output are allowed to be used in other areas of the graph
            if pn in self.pattern_returning_nodes:
                continue

            for user in gn.users:
                # If this node has users that were not in `lookup`, then it must leak out of the
                # pattern subgraph
                if user not in lookup:
                    return False
        return True

    def _remove_overlapping_matches(self, matches: List[InternalMatch]) -> List[InternalMatch]:
        """Keep matches in order, dropping any whose non-placeholder, non-output graph
        nodes were already claimed by an earlier kept match."""
        non_overlapping_matches: List[InternalMatch] = []
        nodes_matched: Set[Node] = set()

        for match in matches:
            found_overlap = False
            for pn, gn in match.nodes_map.items():
                if pn.op not in {"placeholder", "output"} and gn in nodes_matched:
                    found_overlap = True
                    break

            if not found_overlap:
                non_overlapping_matches.append(match)
                for pn, gn in match.nodes_map.items():
                    if pn.op not in {"placeholder", "output"}:
                        nodes_matched.add(gn)
        return non_overlapping_matches

    def _match_literals(self, pn: Any, gn: Any, match: InternalMatch) -> bool:
        """Match a pair of arguments of which at most one is a Node.

        A pattern placeholder may bind to a literal in the graph (the binding is
        recorded in ``match.nodes_map``); two literals match only if they have the
        exact same type and compare equal.
        """
        assert not (isinstance(pn, Node) and isinstance(gn, Node)), "pn and gn cannot both be Node"

        if isinstance(pn, Node) and not isinstance(gn, Node):
            if pn.op == "placeholder":
                # Check if we've already matched these nodes in the current
                # traversal
                if pn in match.nodes_map:
                    return match.nodes_map[pn] == gn

                match.nodes_map[pn] = gn
                return True
            else:
                return False
        elif not isinstance(pn, Node) and isinstance(gn, Node):
            return False
        else:
            # Exact type comparison on purpose: e.g. 1, 1.0 and True compare equal
            # with `==` but are different literals.
            return type(gn) == type(pn) and gn == pn

    def _match_nodes(self, pn: Node, gn: Node, match: InternalMatch) -> bool:
        """Recursively match pattern node ``pn`` against graph node ``gn``.

        On success the pairs discovered along the way are recorded in
        ``match.nodes_map`` and True is returned. On a failure detected after ``pn``
        was optimistically recorded, ``match.nodes_map`` is restored in place to its
        content at entry and False is returned.
        """
        logger.info("  matching %s to %s", pn, gn)

        assert isinstance(pn, Node) and isinstance(gn, Node), f"pn and gn must be Node, pn: {pn}, gn: {gn}"

        # Check if we've already matched these nodes in the current
        # traversal
        if pn in match.nodes_map:
            return match.nodes_map[pn] == gn

        # TODO: use a more efficient way to check if gn is matched before: two-way dict
        if gn in match.nodes_map.values():
            return False

        if not self._nodes_are_equal(pn, gn):
            return False

        # Optimistically mark `pn` as a match for `gn`, and save a snapshot of the
        # mapping so it can be restored if the arguments turn out not to match
        saved_nodes_map = match.nodes_map.copy()
        match.nodes_map[pn] = gn

        # Placeholder is a wildcard and can be matched with any python object
        # (including list/tuple)
        if pn.op == "placeholder":
            return True

        # Recursively traverse upwards to check if `pn` is a true
        # match for `gn`
        match_found = True

        def _match_args(args1: Union[List, Tuple], args2: Union[List, Tuple]) -> bool:
            if len(args1) != len(args2):
                return False

            for a1, a2 in zip(args1, args2):
                if isinstance(a1, Node) and isinstance(a2, Node):
                    matched = self._match_nodes(a1, a2, match)
                elif isinstance(a1, (list, tuple)) and isinstance(a2, (list, tuple)):
                    matched = _match_args(a1, a2)
                else:
                    matched = self._match_literals(a1, a2, match) or self.ignore_literals

                if not matched:
                    return False

            return True

        # Flatten all args/kwargs into 1 list of args
        pn_args, gn_args = None, None
        if (
            (len(pn.args) != len(gn.args) or list(pn.kwargs.keys()) != list(gn.kwargs.keys())) and
            pn.op == "call_function" and
            isinstance(pn.target, torch._ops.OpOverload)
        ):
            # The two calls spell their arguments differently (positional vs keyword,
            # or omitted defaults); normalize both against the operator schema.
            args_schema = pn.target._schema.arguments

            def get_all_arguments(orig_args, orig_kwargs):
                all_args = []
                for i, schema in enumerate(args_schema):
                    if schema.name in orig_kwargs:
                        all_args.append(orig_kwargs[schema.name])
                    elif not schema.kwarg_only and i < len(orig_args):
                        all_args.append(orig_args[i])
                    else:
                        all_args.append(schema.default_value)
                return all_args

            pn_args = get_all_arguments(pn.args, pn.kwargs)
            gn_args = get_all_arguments(gn.args, gn.kwargs)

        elif len(pn.args) == len(gn.args) and list(pn.kwargs.keys()) == list(gn.kwargs.keys()):
            pn_args = list(pn.args)
            gn_args = list(gn.args)
            pn_args.extend(list(pn.kwargs.values()))
            gn_args.extend(list(gn.kwargs.values()))
        else:
            match_found = False

        match_found = (
            match_found and
            pn_args is not None and
            gn_args is not None and
            _match_args(pn_args, gn_args)
        )

        if not match_found:
            # Revert to the snapshot taken before matching the current node. This must
            # mutate the caller's object: rebinding the local name `match` would leave
            # the failed optimistic entries in the caller's `nodes_map`.
            match.nodes_map.clear()
            match.nodes_map.update(saved_nodes_map)
            return False

        return True

    def match(self, graph: Graph) -> List[InternalMatch]:
        """
        Returns:
            The matched subgraphs.
            The returned subgraph would be fully self-contained, meaning the nodes (except placeholder
            and nodes returned by output) can only be consumed by nodes within the matched subgraph.

        Subgraph pattern matcher is implemented with the backtracking style in the following steps:

        1. We first identify all the anchor nodes in the pattern graph. The anchor nodes
        are the "sinks" (nodes with no user other than the output node) of the pattern graph.
        One pattern graph could have multiple anchors if it has multiple return values.

        2. In the target graph, we identify the potential candidate nodes that can be matched
        with each anchor. These anchor-candidate pairs are the starting points for
        pairwise per-node matching.

        3. For each anchor-candidate pair, we simultaneously traverse backwards (DFS) in both
        pattern and target graphs. For every pattern nodes along traversal path, we compare it
        against the target nodes. In case any comparison failed, the match for this anchor-candidate
        pair fails. A match is found when DFS completes traversing the graph. See `self._match_nodes`
        for more details.

        4. In the case of multiple anchors, every anchor will need to find a match using step 3.
        In addition, the matches found between anchors need to have a common intersection node
        in order for the match to be valid. This is implemented with backtracking. See `backtracking`
        for more details.

        Notice: graph traversal must be done in the reverse order because a tensor can have multiple
        consumers, but can only have a single producer. Only with reverse order can we jointly
        traverse the pattern and target graph in a deterministic path.

        Warning: In theory, this backtracking algorithm has an **exponential** time complexity. However,
        in practice, it's unlikely to blow up.

        """
        from torch.fx.passes.utils.fuser_utils import validate_partition

        # find candidate nodes to match with pattern anchors
        match_candidates: Dict[Node, List[Node]] = defaultdict(list)
        for pattern_anchor in self.pattern_anchors:
            # Look the key up first so that an anchor with zero candidates is still
            # recorded; otherwise it would silently be skipped by the search below.
            candidates = match_candidates[pattern_anchor]
            for node in graph.nodes:
                if self._nodes_are_equal(pattern_anchor, node):
                    candidates.append(node)
        match_candidates_list = list(match_candidates.items())

        logger.info("Initial match_candidates_list: %s\n", match_candidates_list)

        matches: List[InternalMatch] = []

        def backtracking(anchor_index, match):
            if anchor_index == len(match_candidates_list):
                # Every anchor has been matched: finalize and record this match
                match.placeholder_nodes = [match.nodes_map[pn] for pn in self.pattern_placeholder_nodes]
                match.returning_nodes = [match.nodes_map[pn] for pn in self.pattern_returning_nodes]
                matches.append(match)

                logger.info("Found a match: %s\n", match)
                return

            pattern_anchor, candidate_nodes = match_candidates_list[anchor_index]
            saved_match = copy.copy(match)

            for node in candidate_nodes:
                logger.info("Trying to match anchor %s to %s", pattern_anchor, node)

                match_found = self._match_nodes(pattern_anchor, node, match)
                if match_found:
                    # match next anchor
                    backtracking(anchor_index + 1, match)
                else:
                    logger.info("Failed to match anchor %s to %s\n", pattern_anchor, node)

                # revert to saved_match before matching with current anchor
                # (a fresh copy, so a match already appended to `matches` is not mutated)
                match = copy.copy(saved_match)

        match = InternalMatch(anchors=self.pattern_anchors)
        # A pattern can only match if every anchor has at least one candidate
        if match_candidates_list and all(candidates for _, candidates in match_candidates_list):
            backtracking(0, match)

        # filter out the matches where the subgraph is not fully_contained
        before = len(matches)
        matches = [match for match in matches if self._is_contained(match.nodes_map)]
        after = len(matches)
        if before != after:
            logger.info("Filtered out %s matches because they are not fully contained", before - after)

        # filter out the matches that form a cycle if the subgraph is fused
        valid_matches = []
        for match in matches:
            matched_compute_nodes = \
                [gn for pn, gn in match.nodes_map.items() if pn.op not in {"placeholder", "output"}]
            if validate_partition(matched_compute_nodes):
                valid_matches.append(match)
        if len(valid_matches) != len(matches):
            logger.info(
                "Filtered out %s matches because "
                "matched subgraph would form a cycle if fused",
                len(matches) - len(valid_matches),
            )

        if self.remove_overlapping_matches:
            before = len(valid_matches)
            matches = self._remove_overlapping_matches(valid_matches)
            after = len(matches)
            if before != after:
                logger.info("Filtered out %s matches because matched subgraphs are overlapping", before - after)

        logger.info("Matches returned: %s", matches)

        return matches