xref: /aosp_15_r20/external/pytorch/torch/fx/passes/utils/matcher_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from dataclasses import dataclass, field
3from collections import defaultdict
4import copy
5import torch
6from torch.fx import (
7    Node,
8    Graph,
9)
10from torch.fx._compatibility import compatibility
11from typing import Dict, List, Set, Any, Union, Tuple
12import logging
13import os
14
# Public names of this module (what ``from ... import *`` exposes).
__all__ = ['SubgraphMatcher', 'InternalMatch']
16
# Set `PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs
def _init_logger():
    """Return this module's logger, configured from ``PYTORCH_MATCHER_LOGLEVEL``.

    The level is read from the ``PYTORCH_MATCHER_LOGLEVEL`` environment
    variable (default ``WARNING``) and applied to both the logger and its
    console handler. Messages are not propagated to ancestor loggers.

    The console handler is tagged with a name and reused on later calls
    (e.g. after a module reload). Before this, each call stacked another
    handler on the same named logger, so every message was printed
    once per call.
    """
    logger = logging.getLogger(__name__)

    level = os.environ.get('PYTORCH_MATCHER_LOGLEVEL', 'WARNING').upper()
    logger.setLevel(level)

    handler_name = "_pytorch_matcher_console"
    console = next((h for h in logger.handlers if h.get_name() == handler_name), None)
    if console is None:
        console = logging.StreamHandler()
        console.set_name(handler_name)
        formatter = logging.Formatter("%(filename)s > %(message)s")
        console.setFormatter(formatter)
        # add the handler to the logger (only once)
        logger.addHandler(console)
    console.setLevel(level)
    logger.propagate = False
    return logger

logger = _init_logger()
33
@compatibility(is_backward_compatible=False)
@dataclass
class InternalMatch:
    """One match of a pattern graph inside a larger (target) graph.

    Produced by ``SubgraphMatcher.match``. ``copy.copy`` gives a copy whose
    containers can be mutated without affecting the original. The anchors
    list is shared, because it is the matcher's list of pattern anchors.
    """
    # Nodes from which the match was found
    anchors: List[Node]
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node] = field(default_factory=dict)

    # nodes in target graph that are matched placeholder in pattern
    placeholder_nodes: List[Node] = field(default_factory=list)

    # nodes in matched subgraph returned by output
    returning_nodes: List[Node] = field(default_factory=list)

    # map from a string name to a node in the target graph
    # only available if the matcher is `SubgraphMatcherWithNameNodesMap`
    name_node_map: Dict[str, Node] = field(default_factory=dict)

    def __copy__(self):
        # Copy every mutable container (including name_node_map, which used to
        # be silently reset to {}); `anchors` is intentionally shared.
        return InternalMatch(anchors=self.anchors, nodes_map=self.nodes_map.copy(),
                             placeholder_nodes=self.placeholder_nodes.copy(),
                             returning_nodes=self.returning_nodes.copy(),
                             name_node_map=self.name_node_map.copy())
56
@compatibility(is_backward_compatible=False)
class SubgraphMatcher:
    """Find occurrences of a pattern ``fx.Graph`` inside a larger ``fx.Graph``.

    Construct it with the pattern and matching options, then call
    :meth:`match` with the graph to search. Each result is an
    :class:`InternalMatch`.
    """

    def __init__(self, pattern: Graph,
                 match_output: bool = False,
                 match_placeholder: bool = False,
                 remove_overlapping_matches: bool = True,
                 ignore_literals: bool = False) -> None:
        """
        Args:
            pattern: the targeted matching pattern, represented in fx.Graph.
            match_output: If True, output node in the pattern graph will be treated as a part of the targeted pattern.
                If False, output node is ignored during match.
            match_placeholder: If True, placeholder node in the pattern graph will be treated as a part of
                the targeted pattern. If False, placeholder nodes will be used as a wildcard.
            remove_overlapping_matches: If True, in the case of overlapping matches, only the first match
                will be returned.
            ignore_literals: If True, will not check if literals are equal and
                will instead treat them as wildcards.

        Raises:
            ValueError: If ``pattern`` contains no nodes.
            AssertionError: If a non-output node of ``pattern`` has no users
                (the pattern contains dead code).
        """

        self.pattern = pattern
        self.match_output = match_output
        self.match_placeholder = match_placeholder
        self.remove_overlapping_matches = remove_overlapping_matches
        self.ignore_literals = ignore_literals

        if len(pattern.nodes) == 0:
            raise ValueError("SubgraphMatcher cannot be initialized with an empty pattern")

        for node in pattern.nodes:
            if node.op != "output":
                assert len(node.users) > 0, \
                       "SubgraphMatcher cannot be initialized with an pattern with dead code"

        # TODO: assert pattern is a connected graph

        self.pattern_placeholder_nodes = [n for n in pattern.nodes if n.op == "placeholder"]
        # Assumes the last node of the pattern graph is its output node.
        output_node = next(iter(reversed(pattern.nodes)))
        # nodes returned by outputs
        self.pattern_returning_nodes: List[Node] = output_node.all_input_nodes

        self.pattern_anchors: List[Node] = []
        if match_output:
            self.pattern_anchors = [output_node]
        else:
            # If a node has output_node as the ONLY user, then this node is a graph sink,
            # and should be matched against as an anchor
            self.pattern_anchors = [n for n in output_node.all_input_nodes if len(n.users) == 1]

    def _match_attributes(self, pn: Node, gn: Node) -> bool:
        """Compare two ``get_attr`` nodes by the attribute values they fetch.

        Attributes of different types never match. Tensors match any other
        tensor (values are not compared).

        Raises:
            RuntimeError: If both attributes share a non-tensor type.
        """
        # Attributes matching is complicated. Right now we only support matching constant tensor
        assert isinstance(pn.target, str), f"pn.target {pn.target} must be a string."
        assert isinstance(gn.target, str), f"gn.target {gn.target} must be a string."

        # TODO(tmanlaibaatar) should probably make this actual API
        def _getattr(model: torch.fx.GraphModule, attr_name: str):
            # Resolve a dotted attribute path (e.g. "sub.weight") on `model`.
            *prefix, attr = attr_name.split(".")
            t = model
            for item in prefix:
                t = getattr(t, item, None)  # type: ignore[assignment]
                assert t is not None

            return getattr(t, attr)

        pn_value = _getattr(pn.graph.owning_module, pn.target)
        gn_value = _getattr(gn.graph.owning_module, gn.target)

        if type(pn_value) != type(gn_value):
            return False

        # Don't require exact match on tensor values.
        if isinstance(pn_value, torch.Tensor):
            return isinstance(gn_value, torch.Tensor)
        raise RuntimeError(f"Unsupported type {pn_value} when matching attributes")

    def _nodes_are_equal(self, pn: Node, gn: Node) -> bool:
        """Return True if pattern node ``pn`` matches ``gn``, ignoring their inputs."""
        # if exact match for placeholder is not required, then use placeholder as a wildcard
        if not self.match_placeholder and pn.op == "placeholder":
            return True

        if pn.op == gn.op:
            if pn.op == "placeholder" or pn.op == "output":
                return True
            elif pn.op == "get_attr":
                return self._match_attributes(pn, gn)
            return pn.target == gn.target
        return False

    def _is_contained(self, nodes_map: Dict[Node, Node]) -> bool:
        """Return True if no matched internal node is used outside the match.

        Placeholders and nodes returned by the pattern's output may have
        users outside the matched subgraph; every other matched node may only
        be used by matched nodes.
        """
        # `lookup` represents all the nodes in `original_graph`
        # that are part of `pattern`

        # Placeholders can be used by other nodes in the graphs
        lookup: Dict[Node, Node] = {gn : pn for pn, gn in nodes_map.items() if pn.op != "placeholder"}
        # nodes returned by output are allowed to be used in other areas of the graph
        returning_nodes = set(self.pattern_returning_nodes)

        for gn, pn in lookup.items():
            if pn in returning_nodes:
                continue

            for user in gn.users:
                # If this node has users that were not in `lookup`, then it must leak out of the
                # pattern subgraph
                if user not in lookup:
                    return False
        return True

    def _remove_overlapping_matches(self, matches: List[InternalMatch]) -> List[InternalMatch]:
        """Keep matches in order, dropping any that reuse an already-kept compute node.

        Only non-placeholder, non-output pattern nodes count towards overlap.
        """
        non_overlapping_matches: List[InternalMatch] = []
        nodes_matched: Set[Node] = set()

        for match in matches:
            found_overlap = False
            for pn, gn in match.nodes_map.items():
                if pn.op not in {"placeholder", "output"} and gn in nodes_matched:
                    found_overlap = True
                    break

            if not found_overlap:
                non_overlapping_matches.append(match)
                for pn, gn in match.nodes_map.items():
                    if pn.op not in {"placeholder", "output"}:
                        nodes_matched.add(gn)
        return non_overlapping_matches

    def _match_literals(self, pn: Any, gn: Any, match: InternalMatch) -> bool:
        """Match a pair of arguments where at most one side is a ``Node``.

        A pattern placeholder may bind to a non-Node value, and that binding is
        recorded in ``match.nodes_map``. Otherwise two non-Node values match
        only when they have the same type and compare equal.
        """
        assert not (isinstance(pn, Node) and isinstance(gn, Node)), "pn and gn cannot both be Node"

        if isinstance(pn, Node) and not isinstance(gn, Node):
            if pn.op == "placeholder":
                # Check if we've already matched these nodes in the current
                # traversal
                if pn in match.nodes_map:
                    return match.nodes_map[pn] == gn

                match.nodes_map[pn] = gn
                return True
            else:
                return False
        elif not isinstance(pn, Node) and isinstance(gn, Node):
            return False
        else:
            return type(gn) == type(pn) and gn == pn

    def _match_nodes(self, pn: Node, gn: Node, match: InternalMatch) -> bool:
        """Match pattern node ``pn`` to ``gn``, recursing into their arguments.

        On success ``match.nodes_map`` is extended with the new pairs. On
        failure it is restored in place to its state on entry, so a failed
        attempt leaves no partial bindings in the caller's ``match``.

        Returns:
            True if ``pn`` (and, recursively, its inputs) match ``gn``.
        """
        logger.info("  matching %s to %s", pn, gn)

        assert isinstance(pn, Node) and isinstance(gn, Node), str(f"pn and gn must be Node, pn: {pn}, gn: {gn}")

        # Check if we've already matched these nodes in the current
        # traversal
        if pn in match.nodes_map:
            return match.nodes_map[pn] == gn

        # TODO: use a more efficient way to check if gn is matched before: two-way dict
        if gn in match.nodes_map.values():
            return False

        if not self._nodes_are_equal(pn, gn):
            return False

        # Optimistically mark `pn` as a match for `gn`. Save the mapping first so
        # it can be restored if the match fails.
        saved_nodes_map = match.nodes_map.copy()
        match.nodes_map[pn] = gn

        # Placeholder is a wildcard and can be matched with any python object
        # (including list/tuple)
        if pn.op == "placeholder":
            return True

        # Recursively traverse upwards to check if `pn` is a true
        # match for `gn`
        def _match_args(args1: Union[List, Tuple], args2: Union[List, Tuple]) -> bool:
            # Element-wise match of two (possibly nested) argument sequences.
            if len(args1) != len(args2):
                return False

            for a1, a2 in zip(args1, args2):
                if isinstance(a1, Node) and isinstance(a2, Node):
                    matched = self._match_nodes(a1, a2, match)
                elif isinstance(a1, (list, tuple)) and isinstance(a2, (list, tuple)):
                    matched = _match_args(a1, a2)
                else:
                    matched = self._match_literals(a1, a2, match) or self.ignore_literals

                if not matched:
                    return False

            return True

        # Flatten all args/kwargs into 1 list of args
        pn_args, gn_args = None, None
        same_arg_layout = (
            len(pn.args) == len(gn.args) and list(pn.kwargs.keys()) == list(gn.kwargs.keys())
        )
        if (
            not same_arg_layout and
            pn.op == "call_function" and
            isinstance(pn.target, torch._ops.OpOverload)
        ):
            # Normalize both call sites against the pattern op's schema: take
            # kwargs by name, positionals by index, otherwise the schema default.
            # This lets the arguments be compared positionally.
            args_schema = pn.target._schema.arguments

            def get_all_arguments(orig_args, orig_kwargs):
                all_args = []
                for i, schema in enumerate(args_schema):
                    if schema.name in orig_kwargs:
                        all_args.append(orig_kwargs[schema.name])
                    elif not schema.kwarg_only and i < len(orig_args):
                        all_args.append(orig_args[i])
                    else:
                        all_args.append(schema.default_value)
                return all_args

            pn_args = get_all_arguments(pn.args, pn.kwargs)
            gn_args = get_all_arguments(gn.args, gn.kwargs)

        elif same_arg_layout:
            pn_args = list(pn.args)
            gn_args = list(gn.args)
            pn_args.extend(list(pn.kwargs.values()))
            gn_args.extend(list(gn.kwargs.values()))
        # Otherwise the argument layouts are incompatible; pn_args stays None.

        match_found = (
            pn_args is not None and
            gn_args is not None and
            _match_args(pn_args, gn_args)
        )

        if not match_found:
            # Revert in place. Rebinding the local name `match` would not undo
            # the entries this call (and its recursive calls) added to the
            # caller's object.
            match.nodes_map.clear()
            match.nodes_map.update(saved_nodes_map)
            return False

        return True

    def match(self, graph: Graph) -> List[InternalMatch]:
        """
        Returns:
            The matched subgraphs.
            The returned subgraphs are fully self-contained: their nodes (except placeholders
            and nodes returned by output) are only consumed by nodes within the matched subgraph.

        Subgraph pattern matcher is implemented with the backtracking style in the following steps:

        1. We first identify all the anchor nodes in the pattern graph. The anchor nodes
        are the "sinks" (nodes with no user other than the output node) of the pattern graph.
        One pattern graph could have multiple anchors if it has multiple return values.

        2. In the target graph, we identify the potential candidate nodes that can be matched
        with each anchor. These anchor-candidate pairs are the starting points for
        pairwise per-node matching. If any anchor has no candidate, no match is possible.

        3. For each anchor-candidate pair, we simultaneously traverse backwards (DFS) in both
        pattern and target graphs. For every pattern node along the traversal path, we compare it
        against the target nodes. In case any comparison failed, the match for this anchor-candidate
        pair fails. A match is found when DFS completes traversing the graph. See `self._match_nodes`
        for more details.

        4. In the case of multiple anchors, every anchor will need to find a match using step 3.
        In addition, the matches found between anchors need to have a common intersection node
        in order for the match to be valid. This is implemented with backtracking. See `backtracking`
        for more details.

        Notice: graph traversal must be done in the reverse order because a tensor can have multiple
        consumers, but can only have a single producer. Only in reverse order can we jointly
        traverse the pattern and target graph in a deterministic path.

        Warning: In theory, this backtracking algorithm has an **exponential** time complexity. However,
        in practice, it's unlikely to blow up.

        """
        from torch.fx.passes.utils.fuser_utils import validate_partition

        # find candidate nodes to match with pattern anchors. Every anchor keeps an
        # entry even when it has no candidates, so backtracking can never finish
        # without binding every anchor.
        match_candidates_list: List[Tuple[Node, List[Node]]] = [
            (pattern_anchor, [node for node in graph.nodes if self._nodes_are_equal(pattern_anchor, node)])
            for pattern_anchor in self.pattern_anchors
        ]

        logger.info("Initial match_candidates_list: %s\n", match_candidates_list)

        matches: List[InternalMatch] = []

        def backtracking(anchor_index, match):
            if anchor_index == len(match_candidates_list):
                match.placeholder_nodes = [match.nodes_map[pn] for pn in self.pattern_placeholder_nodes]
                match.returning_nodes = [match.nodes_map[pn] for pn in self.pattern_returning_nodes]
                matches.append(match)

                logger.info("Found a match: %s\n", match)
                return

            pattern_anchor, candidate_nodes = match_candidates_list[anchor_index]
            saved_match = copy.copy(match)

            for node in candidate_nodes:
                logger.info("Trying to match anchor %s to %s", pattern_anchor, node)

                match_found = self._match_nodes(pattern_anchor, node, match)
                if match_found:
                    # match next anchor
                    backtracking(anchor_index + 1, match)
                else:
                    logger.info("Failed to match anchor %s to %s\n", pattern_anchor, node)

                # revert to saved_match before matching with current anchor
                match = copy.copy(saved_match)

        match = InternalMatch(anchors=self.pattern_anchors)
        # An anchor without any candidate makes a complete match impossible.
        if match_candidates_list and all(candidates for _, candidates in match_candidates_list):
            backtracking(0, match)

        # filter out the matches where the subgraph is not fully_contained
        before = len(matches)
        matches = [match for match in matches if self._is_contained(match.nodes_map)]
        after = len(matches)
        if before != after:
            logger.info("Filtered out %s matches because they are not fully contained", before - after)

        # filter out the matches that form a cycle if the subgraph is fused
        valid_matches = []
        for match in matches:
            matched_compute_nodes = \
                [gn for pn, gn in match.nodes_map.items() if pn.op not in {"placeholder", "output"}]
            if validate_partition(matched_compute_nodes):
                valid_matches.append(match)
        if len(valid_matches) != len(matches):
            logger.info(
                "Filtered out %s matches because matched subgraph would form a cycle if fused",
                len(matches) - len(valid_matches),
            )

        if self.remove_overlapping_matches:
            before = len(valid_matches)
            matches = self._remove_overlapping_matches(valid_matches)
            after = len(matches)
            if before != after:
                logger.info("Filtered out %s matches because matched subgraphs are overlapping", before - after)

        logger.info("Matches returned: %s", matches)

        return matches
402