xref: /aosp_15_r20/external/pytorch/torch/fx/passes/utils/matcher_with_name_node_map_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from typing import Dict, List, Tuple
2
3from torch.fx import Graph, GraphModule, Node
4from torch.fx._compatibility import compatibility
5
6from .matcher_utils import InternalMatch, SubgraphMatcher
7
8
# Public API of this module: only the matcher class is exported (e.g. by
# ``from ... import *``); the underscore-prefixed graph-splitting helper is
# module-private.
__all__ = ["SubgraphMatcherWithNameNodeMap"]
10
11
def _split_to_graph_and_name_node_map(
    gm: GraphModule,
) -> Tuple[GraphModule, Dict[str, Node]]:
    """Strip the trailing ``name -> node`` dict from a pattern's output.

    The pattern graph is expected to return a tuple whose last element is a
    dict mapping user-chosen names to nodes of the pattern graph, e.g.
    ``return relu, {"conv": conv, "relu": relu}``. The output node is rewritten
    in place so that the graph only returns the remaining elements. Those
    elements are re-packed as a list, because ``*out`` unpacking yields a list.
    The pytree output spec in the graph's codegen is updated to match, and
    ``gm`` is recompiled.

    Args:
        gm: pattern GraphModule whose codegen carries pytree info
            (``gm._out_spec`` must be set). It is modified in place.

    Returns:
        ``(gm, name_node_map)``: the same, now recompiled module and the
        extracted dict. ``name_node_map`` stays empty if no output node is
        found in the graph.

    Raises:
        AssertionError: if ``gm`` has no output spec, or if the pattern output
            is not a tuple of at least two elements whose last element is a
            dict.
    """
    from torch.fx.graph import _PyTreeInfo
    from torch.utils._pytree import tree_flatten, tree_unflatten

    name_node_map: Dict[str, Node] = {}
    for n in gm.graph.nodes:
        if n.op != "output":
            continue
        assert gm._out_spec is not None
        # Rebuild the structured (tuple) output from the flat output args.
        output = tree_unflatten(n.args[0], gm._out_spec)
        assert isinstance(
            output, tuple
        ), "Expecting the pattern graph to return a tuple"
        assert (
            len(output) >= 2
        ), "Expecting the pattern graph to have at least two outputs"
        *out, name_node_map = output
        # Validate the trailing element before deriving anything from it.
        # Use the builtin ``dict`` rather than the deprecated ``typing.Dict``
        # alias for the runtime type check.
        assert isinstance(
            name_node_map, dict
        ), "Expecting the input graph to have a dict output as the last element"
        # Only the non-dict outputs remain graph outputs. Recompute the spec
        # so the generated ``forward`` unflattens them consistently.
        flattened, out_spec = tree_flatten(out)
        n.args = (flattened,)
        orig_pytree_info = gm._graph._codegen.pytree_info  # type: ignore[attr-defined]
        gm._graph._codegen.pytree_info = _PyTreeInfo(  # type: ignore[attr-defined]
            orig_pytree_info.orig_args, orig_pytree_info.in_spec, out_spec
        )
    gm.recompile()
    return gm, name_node_map
41
42
@compatibility(is_backward_compatible=False)
class SubgraphMatcherWithNameNodeMap(SubgraphMatcher):
    """A ``SubgraphMatcher`` whose matched nodes can be looked up by name.

    The pattern must follow a specific convention. In addition to its regular
    outputs, it returns a dictionary as the last element of its output tuple.
    That dictionary maps a user-chosen name (str) to a node of the pattern
    graph. Every ``InternalMatch`` produced by :meth:`match` then carries the
    same names in ``name_node_map``, each mapped to the corresponding node of
    the target graph.

    Unlike ``SubgraphMatcher``, the constructor takes the pattern as a
    ``GraphModule`` (``pattern_gm``) instead of a ``Graph``. Removing the extra
    dictionary output rewrites the pattern graph and requires recompiling the
    module, so ``pattern_gm`` is modified in place.

    Example::
        def pattern(x, weight):
            conv = F.conv2d(x, weight)
            relu = F.relu(conv)
            return relu, {"conv": conv, "relu": relu}

        def target_graph(x, weight):
            conv = F.conv2d(x, weight)
            relu = F.relu(conv)
            relu *= 2
            return relu

        pattern_gm = capture_pre_autograd_graph(pattern, example_inputs)
        target_gm = capture_pre_autograd_graph(target_graph, example_inputs)
        matcher = SubgraphMatcherWithNameNodeMap(pattern_gm)
        matches = matcher.match(target_gm)
        for match in matches:
            match.name_node_map["conv"].meta["annotation"] = ...
    """

    def __init__(
        self,
        pattern_gm: GraphModule,
        match_output: bool = False,
        match_placeholder: bool = False,
        remove_overlapping_matches: bool = True,
        ignore_literals: bool = False,
    ) -> None:
        # Remove the name->node dict from the pattern's outputs; the base
        # matcher only sees the remaining graph.
        stripped_gm, pattern_names = _split_to_graph_and_name_node_map(pattern_gm)
        # name -> node in the *pattern* graph; resolved per match in `match`.
        self.name_node_map = pattern_names
        super().__init__(
            stripped_gm.graph,
            match_output,
            match_placeholder,
            remove_overlapping_matches,
            ignore_literals,
        )

    def match(self, graph: Graph) -> List[InternalMatch]:
        """Find the pattern in ``graph`` and resolve the named pattern nodes.

        Each returned ``InternalMatch`` has its ``name_node_map`` filled in.
        Every name declared by the pattern maps to the node in ``graph`` that
        the corresponding pattern node was matched to, e.g.
        ``{"conv": target_conv_node, "relu": target_relu_node}``.

        This requires the pattern to return the name->node dictionary as an
        extra output. Instead of::

            def pattern(...):
                ...
                return relu

        write::

            def pattern(...):
                ...
                return relu, {"conv": conv, "relu": relu}
        """
        matches = super().match(graph)
        for found in matches:
            # Translate each pattern node into the target node it matched.
            found.name_node_map.update(
                (name, found.nodes_map[pattern_node])
                for name, pattern_node in self.name_node_map.items()
            )
        return matches
115