from typing import Dict, List, Tuple

from torch.fx import Graph, GraphModule, Node
from torch.fx._compatibility import compatibility

from .matcher_utils import InternalMatch, SubgraphMatcher


__all__ = ["SubgraphMatcherWithNameNodeMap"]


def _split_to_graph_and_name_node_map(
    gm: GraphModule,
) -> Tuple[GraphModule, Dict[str, Node]]:
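    """Splits the pattern GraphModule's output into the "real" outputs and the trailing
    name-to-node dict: the dict is stripped from the graph's output node (so the rewritten
    pattern matches targets that do not return it) and is returned separately alongside
    the modified GraphModule.
    """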
    from torch.fx.graph import _PyTreeInfo
    from torch.utils._pytree import tree_flatten, tree_unflatten

    name_node_map = {}
    for n in gm.graph.nodes:
        if n.op == "output":
            assert gm._out_spec is not None
            output = tree_unflatten(n.args[0], gm._out_spec)
            assert isinstance(
                output, tuple
            ), "Expecting the pattern graph to return a tuple"
            assert (
                len(output) >= 2
            ), "Expecting the pattern graph to have at least two outputs"
            *out, name_node_map = output
            flattened, out_spec = tree_flatten(out)
            assert isinstance(
                name_node_map, dict
            ), "Expecting the pattern graph to return a dict as the last output"
            n.args = (flattened,)
            # Update the output pytree spec so the stripped dict is no longer part of it
            orig_pytree_info = gm._graph._codegen.pytree_info  # type: ignore[attr-defined]
            gm._graph._codegen.pytree_info = _PyTreeInfo(  # type: ignore[attr-defined]
                orig_pytree_info.orig_args, orig_pytree_info.in_spec, out_spec
            )
    gm.recompile()
    return gm, name_node_map


@compatibility(is_backward_compatible=False)
class SubgraphMatcherWithNameNodeMap(SubgraphMatcher):
    """Extends SubgraphMatcher to support querying the matched subgraph nodes by node name.
    This requires the pattern to have a specific format: in addition to its regular outputs,
    it returns a dictionary (as the last output) that maps node names to the corresponding
    nodes in the pattern graph, see Example for more details.

    The difference from SubgraphMatcher is that it takes a `pattern_gm` GraphModule as input
    during initialization, since we need to modify the graph (which requires recompiling the
    GraphModule).

    Example::
        def pattern(x, weight):
            conv = F.conv2d(x, weight)
            relu = F.relu(conv)
            return relu, {"conv": conv, "relu": relu}

        def target_graph(x, weight):
            conv = F.conv2d(x, weight)
            relu = F.relu(conv)
            relu *= 2
            return relu

        pattern_gm = capture_pre_autograd_graph(pattern, example_inputs)
        target_gm = capture_pre_autograd_graph(target_graph, example_inputs)
        matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)
        matches = matcher.match(target_gm.graph)
        for match in matches:
            match.name_node_map["conv"].meta["annotation"] = ...

    """

    def __init__(
        self,
        pattern_gm: GraphModule,
        match_output: bool = False,
        match_placeholder: bool = False,
        remove_overlapping_matches: bool = True,
        ignore_literals: bool = False,
    ) -> None:
        pattern_gm, name_node_map = _split_to_graph_and_name_node_map(pattern_gm)
        self.name_node_map = name_node_map
        super().__init__(
            pattern_gm.graph,
            match_output,
            match_placeholder,
            remove_overlapping_matches,
            ignore_literals,
        )

    def match(self, graph: Graph) -> List[InternalMatch]:
        """The returned InternalMatch will have name_node_map populated with a map
        from node name (str) to the target node, e.g.
        {"conv": target_conv_node, "relu": target_relu_node}

        This requires the pattern graph to return an additional output that maps
        node names to nodes, e.g. instead of:
        ```
        def pattern(...):
            ...
            return relu
        ```
        we should do:
        ```
        def pattern(...):
            ...
            return relu, {"conv": conv, "relu": relu}
        ```
        """
        internal_matches = super().match(graph)
        for internal_match in internal_matches:
            for k, n in self.name_node_map.items():
                internal_match.name_node_map[k] = internal_match.nodes_map[n]
        return internal_matches