xref: /aosp_15_r20/external/pytorch/torch/ao/ns/fx/graph_matcher.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import collections
3import enum
4from typing import Any, Dict, List, Optional, Set, Tuple
5
6import torch
7from torch.ao.quantization import FakeQuantizeBase, ObserverBase
8from torch.ao.quantization.utils import getattr_from_fqn
9from torch.fx import GraphModule
10from torch.fx.graph import Graph, Node
11
12from .mappings import get_base_name_to_sets_of_related_ops, get_unmatchable_types_map
13from .ns_types import NSNodeTargetType, NSSubgraph
14from .pattern_utils import (
15    end_node_matches_reversed_fusion,
16    get_reversed_fusions,
17    get_type_a_related_to_b,
18)
19
20
# Short alias for the quantized op namespace, as used in comments below
# (e.g. `toq.linear`). NOTE(review): not referenced by code in this file as
# shown; confirm whether other modules rely on it before removing.
toq = torch.ops.quantized
22
23
24def _get_output_nodes(g: Graph) -> List[Node]:
25    return [n for n in g.nodes if n.op == "output"]
26
27
class _NSGraphMatchableSubgraphsIterator:
    """
    Iterates through the graph of gm, starting with the output nodes
    and continuing backwards toward the inputs.

    1. Returns matchable subgraphs, in traversal order. A subgraph is defined
       by (start_node, end_node), plus the base_op_node used for type matching.
    2. Skips over non-matchable subgraphs.

    The traversal uses a LIFO stack (``self.stack``) seeded with the graph's
    output nodes, so the walk is depth-first. Every visited node, including
    nodes absorbed into a fusion, is recorded in ``self.seen_nodes`` and is
    never visited again.
    """

    def __init__(
        self,
        gm: GraphModule,
        non_matchable_functions: Set[NSNodeTargetType],
        non_matchable_modules: Set[NSNodeTargetType],
        non_matchable_methods: Set[NSNodeTargetType],
    ) -> None:
        """
        Args:
            gm: the GraphModule whose graph is traversed.
            non_matchable_functions: ``call_function`` targets to skip.
            non_matchable_modules: module types; a ``call_module`` node whose
                module is an instance of any of them is skipped.
            non_matchable_methods: ``call_method`` targets to skip.
        """
        self.gm: GraphModule = gm
        self.non_matchable_functions: Set[NSNodeTargetType] = non_matchable_functions
        self.non_matchable_modules: Set[NSNodeTargetType] = non_matchable_modules
        self.non_matchable_methods: Set[NSNodeTargetType] = non_matchable_methods
        # nodes already visited (or absorbed into a fusion); skipped on later pops
        self.seen_nodes: Set[Node] = set()
        # pending nodes, popped from the end (LIFO => depth-first walk)
        self.stack: List[Node] = []
        for start_node in _get_output_nodes(self.gm.graph):
            self.stack.append(start_node)

    def __iter__(self):
        return self

    def __next__(self) -> NSSubgraph:
        """
        Returns the next matchable subgraph.

        Pops nodes off ``self.stack`` until one produces a matchable subgraph.
        Each unseen popped node is treated as a subgraph's end node. If it ends
        a known reversed fusion pattern, the walk follows ``args[0]`` back to
        the pattern's first node, which becomes the start node. The inputs of
        the start node are then pushed onto the stack.

        Raises:
            StopIteration: when the stack is exhausted.
        """
        while len(self.stack) > 0:
            cur_end_node = self.stack.pop()
            # a node may be pushed more than once (once per visited node that
            # uses it); only the first visit counts
            if cur_end_node in self.seen_nodes:
                continue

            # for subgraphs which are single nodes, start_node == end_node
            # for subgraphs with more than one node, start node != end_node
            cur_start_node = cur_end_node
            # Subgraphs like linear-relu have the base node as the start node.
            # Subgraphs like dequantize-linear-relu-to(torch.float16) have the
            #   base node as the second node.
            # The cur_base_op_node var will move to the actual node during
            #   the fusion matching later in this code block.
            cur_base_op_node = cur_end_node

            # Check for potential fusions. For now, we are greedy
            # and always skip all non-base nodes of a fusion.  For example,
            # if we match linear-relu backwards, we will always skip the
            # relu node and attempt to match the linear node.  This can
            # be made configurable later if needed.
            for _reverse_fusion_ops, base_op_idx in get_reversed_fusions():
                # seen_nodes is passed in as well; presumably so that patterns
                # overlapping already-visited nodes are rejected - TODO confirm
                is_match = end_node_matches_reversed_fusion(
                    cur_end_node, _reverse_fusion_ops, self.gm, self.seen_nodes
                )
                if is_match:
                    # navigate to the base node
                    # Walk len(_reverse_fusion_ops) - 1 steps back via args[0],
                    # marking each node passed as seen. After step
                    # rev_fusion_idx, cur_start_node is the pattern node at
                    # reversed position rev_fusion_idx + 1.
                    for rev_fusion_idx in range(len(_reverse_fusion_ops) - 1):
                        self.seen_nodes.add(cur_start_node)
                        # for now, assume that there are no other nodes
                        # which need to be added to the stack
                        cur_start_node = cur_start_node.args[0]  # type: ignore[assignment]
                        # if the base op index matches the current node, set it
                        # Assuming base_op_idx indexes the forward (non-reversed)
                        # pattern, it corresponds to step len - 2 - base_op_idx.
                        # If the base op is the pattern's last op (the end node)
                        # this is -1, never matches, and cur_base_op_node stays
                        # equal to cur_end_node.
                        rev_base_op_idx = len(_reverse_fusion_ops) - 2 - base_op_idx
                        if rev_fusion_idx == rev_base_op_idx:
                            cur_base_op_node = cur_start_node
                    # only the first matching fusion pattern is applied
                    break

            self.seen_nodes.add(cur_start_node)
            # add args of previous nodes to stack
            # (inputs are pushed even if this subgraph turns out unmatchable,
            # so the traversal continues past skipped nodes)
            for arg in cur_start_node.all_input_nodes:
                self._recursively_add_node_arg_to_stack(arg)

            # skip unmatchable nodes
            # note: this check is done on the base op node, i.e.
            # if we are matching linear-relu in reverse, this would do the matchable
            # check on the linear
            if not self._is_matchable(cur_base_op_node):
                continue

            # If an observer or a fake_quant was not matched as a part of
            # a pattern of multiple nodes, ignore it. One case where this is
            # relevant is an observer on a graph input, which was added because
            # it is necessary for the next node.
            if cur_end_node.op == "call_module" and cur_start_node is cur_end_node:
                maybe_obs = getattr_from_fqn(self.gm, cur_end_node.target)  # type: ignore[arg-type]
                if isinstance(maybe_obs, (ObserverBase, FakeQuantizeBase)):
                    continue

            return NSSubgraph(
                start_node=cur_start_node,
                end_node=cur_end_node,
                base_op_node=cur_base_op_node,
            )

        raise StopIteration

    def _recursively_add_node_arg_to_stack(self, arg: Any) -> None:
        """
        Adds all of the nodes in this arg to the stack, properly navigating
        through list, dicts and tuples.

        Nodes are pushed directly. fx immutable lists, plain tuples (exact
        type only) and fx immutable dicts (values only) are recursed into.
        Any other value is silently ignored.
        """
        if isinstance(arg, Node):
            self.stack.append(arg)
        elif (
            isinstance(arg, torch.fx.immutable_collections.immutable_list)
            or type(arg) is tuple
        ):
            for inner_arg in arg:
                self._recursively_add_node_arg_to_stack(inner_arg)
        elif isinstance(arg, torch.fx.immutable_collections.immutable_dict):
            for value in arg.values():
                self._recursively_add_node_arg_to_stack(value)

    def _is_matchable(self, node: Node) -> bool:
        """
        Return whether ``node`` may serve as the base op of a matched subgraph.

        call_function / call_method nodes are matchable unless their target is
        in the corresponding non-matchable set. call_module nodes are matchable
        unless the referenced module is an instance of a non-matchable module
        type. Every other op kind is not matchable.
        """
        if node.op == "call_function":
            return node.target not in self.non_matchable_functions
        elif node.op == "call_module":
            assert isinstance(node.target, str)
            target_mod = getattr_from_fqn(self.gm, node.target)
            return not any(
                isinstance(target_mod, t)  # type: ignore[arg-type]
                for t in self.non_matchable_modules
            )
        elif node.op == "call_method":
            return node.target not in self.non_matchable_methods
        else:
            return False
157
158
class GraphMatchingException(Exception):
    """
    Raised when the matchable subgraphs of two graphs cannot be paired up,
    e.g. because a pair is unrelated or one graph has more subgraphs.
    """
163
164
class SubgraphTypeRelationship(enum.Enum):
    # Same type, and the type is known to Numerical Suite.
    # Example: F.linear vs F.linear, or nn.Conv2d vs nn.Conv2d.
    EQUAL = 1
    # Same type, but unknown to Numerical Suite (user defined type, etc).
    # The member name's spelling is kept as-is for backward compatibility.
    EQUAL_BUT_UKNOWN = 2
    # Both types are known and belong to the same related-ops set,
    # but they differ. Example: F.linear vs toq.linear.
    RELATED_BUT_NOT_EQUAL = 3
    # The two types are not related.
    NOT_RELATED = 4
177
178
def _get_subgraph_relationship_type(
    subgraph_a: NSSubgraph,
    subgraph_b: NSSubgraph,
    gm_a: GraphModule,
    gm_b: GraphModule,
    type_a_related_to_b: Set[Tuple[NSNodeTargetType, NSNodeTargetType]],
) -> SubgraphTypeRelationship:
    """
    Classify how the base ops of two subgraphs relate to each other.

    Only the ``base_op_node`` of each subgraph is compared. For two equal,
    known function/method targets, the result also depends on whether each
    subgraph starts at its base op (i.e. has no node before it).

    Args:
        subgraph_a: subgraph from ``gm_a``.
        subgraph_b: subgraph from ``gm_b``.
        gm_a: module owning ``subgraph_a``, used to resolve ``call_module`` targets.
        gm_b: module owning ``subgraph_b``, used to resolve ``call_module`` targets.
        type_a_related_to_b: set of known-related ``(type_a, type_b)`` pairs.

    Returns:
        EQUAL_BUT_UKNOWN: identical target/module type, but the pair is not in
            ``type_a_related_to_b``.
        EQUAL: the pair is known and the targets/types are identical (for
            functions/methods, both or neither subgraph start at the base op).
        RELATED_BUT_NOT_EQUAL: the pair is known but the targets/types differ,
            or the function/method target is identical but only one subgraph
            starts at its base op.
        NOT_RELATED: anything else, including base ops of different kinds
            (``call_function`` and ``call_method`` are treated as compatible).

    Raises:
        AssertionError: for ``call_module`` base ops that are not the start
            node of their subgraph, or whose target is not a string.
    """
    node_a = subgraph_a.base_op_node
    node_b = subgraph_b.base_op_node
    function_like_ops = ("call_function", "call_method")

    # TODO(next): make this code handle matching by what is before the base op
    if node_a.op != node_b.op and not (
        node_a.op in function_like_ops and node_b.op in function_like_ops
    ):
        return SubgraphTypeRelationship.NOT_RELATED

    if node_a.op in function_like_ops:
        key = (node_a.target, node_b.target)
        if key not in type_a_related_to_b:
            if node_a.target == node_b.target:
                return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN
            return SubgraphTypeRelationship.NOT_RELATED

        # From here on the pair is known to be related.
        if node_a.target != node_b.target:
            return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL

        # Same known target: a subgraph that starts at its base op has no
        # preceding node (e.g. no dequantize before the linear). If exactly
        # one side has such a preceding node, the subgraphs are only related.
        a_starts_at_base = subgraph_a.base_op_node == subgraph_a.start_node
        b_starts_at_base = subgraph_b.base_op_node == subgraph_b.start_node
        if a_starts_at_base != b_starts_at_base:
            return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL
        # TODO(future PR): check for matches start_op_node and base_op_node
        return SubgraphTypeRelationship.EQUAL

    if node_a.op == "call_module":
        assert (
            subgraph_a.base_op_node == subgraph_a.start_node
            and subgraph_b.base_op_node == subgraph_b.start_node
        ), "Matching call_module patterns where base_op_node != start_node is not supported yet"
        # for call_module, we need to look up the modules to do the type check
        assert isinstance(node_a.target, str)
        mod_a = getattr_from_fqn(gm_a, node_a.target)
        assert isinstance(node_b.target, str)
        mod_b = getattr_from_fqn(gm_b, node_b.target)

        type_a, type_b = type(mod_a), type(mod_b)
        if (type_a, type_b) not in type_a_related_to_b:
            if type_a is type_b:
                return SubgraphTypeRelationship.EQUAL_BUT_UKNOWN
            return SubgraphTypeRelationship.NOT_RELATED
        if type_a is type_b:
            return SubgraphTypeRelationship.EQUAL
        return SubgraphTypeRelationship.RELATED_BUT_NOT_EQUAL

    return SubgraphTypeRelationship.NOT_RELATED
248
249
def _get_name_for_subgraph(
    subgraph_a: NSSubgraph,
    gm_a: GraphModule,
    base_name_to_sets_of_related_ops: Dict[str, Set[NSNodeTargetType]],
    existing_names: Set[str],
) -> str:
    """
    Build a unique, model-independent name for a subgraph and record it.

    The name has the form ``base_op_<base name>_<counter>``:

    * ``<base name>`` is the key of the related-ops set containing the type of
      the subgraph's base op (e.g. 'torch.nn.functional.linear' for a linear
      op), or ``None`` if no set contains it. If several sets contain the
      type, the last one in iteration order is used.
    * ``<counter>`` is the smallest non-negative integer that makes the name
      absent from ``existing_names``. That is the number of subgraphs with a
      related base op type that were named before this one.

    The chosen name is added to ``existing_names``, which is mutated in place.

    Example: in the graph

        linear0 -> relu0 -> linear1 -> relu1

    the subgraphs are (linear0, relu0) and (linear1, relu1). Iterating from
    the output backwards, (linear1, relu1) is named
    `base_op_torch.nn.functional.linear_0` and (linear0, relu0) is named
    `base_op_torch.nn.functional.linear_1`.

    Node names are not used because:
    A. fusions must be supported, and
    B. some Numeric Suite APIs can be called without having all of the models
       in memory.

    For example, the subgraphs

        (1) ... -> linear0 -> relu0 -> ...
        (2) ... -> linear_relu0 -> ...

    must be matched without being inspected together. If both graphs are
    iterated in the same order and the rest of the graphs match, this scheme
    gives (1) and (2) the same name without either knowing about the other.
    """
    target_type = _get_node_target_type(subgraph_a.base_op_node, gm_a)
    matching_base_names = [
        base_name
        for base_name, related_ops in base_name_to_sets_of_related_ops.items()
        if target_type in related_ops
    ]
    # last match wins when the type belongs to more than one set
    target_base_type = matching_base_names[-1] if matching_base_names else None

    prefix = "base_op_" + str(target_base_type)
    idx = 0
    while f"{prefix}_{idx}" in existing_names:
        idx += 1
    unique_name = f"{prefix}_{idx}"
    existing_names.add(unique_name)
    return unique_name
301
302
303def _get_node_target_type(node: Node, gm: GraphModule) -> Optional[NSNodeTargetType]:
304    if node.op in ("call_function", "call_method"):
305        return node.target
306    elif node.op == "call_module":
307        assert isinstance(node.target, str)
308        mod = getattr_from_fqn(gm, node.target)
309        return type(mod)
310    return None
311
312
def get_matching_subgraph_pairs(
    gm_a: GraphModule,
    gm_b: GraphModule,
    base_name_to_sets_of_related_ops: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
    unmatchable_types_map: Optional[Dict[str, Set[NSNodeTargetType]]] = None,
) -> Dict[str, Tuple[NSSubgraph, NSSubgraph]]:
    """
    Matches matchable subgraphs of graph_a to graph_b.

    For a node, "matchable" is intended to mean a node which is not an
    observer, fake_quant, quant or dequant. Concretely, a node is skipped when
    its function, module type or method appears in ``unmatchable_types_map``
    (``get_unmatchable_types_map()`` is used when it is ``None``).

    A subgraph can contain one or more nodes. Matchability is checked on the
    subgraph's base op node (e.g. the linear of a linear-relu fusion). We
    assume no observers are inserted in the middle of a fusion.

    A subgraph is defined by (start_node, end_node). We assume that only
    start_node and end_node are linked with the surrounding graph, all other
    nodes in a subgraph are self-contained.

    A pair of nodes is "related" if both nodes represent the same mathematical
    operation across different quantization flavors. For example,
    `F.linear` and `torch.ops.quantized.linear` are related, and
    `F.linear` and `torch.nn.Conv` are not related.

    For each matchable pair of nodes node_a and node_b, they will match
    if node_a and node_b are related.

    For graphs A and B, they will match iff:
    1. the number of matchable subgraphs in A and B is equal
    2. when iterating through the matchable subgraphs of A and B in the same order, each
       corresponding pair of base nodes is related.

    Pairs whose base ops have the same type but that type is unknown to
    Numeric Suite (user defined types, etc.) are skipped and do not appear in
    the result.

    This enables us to find the corresponding subgraphs between
    graphs of related models.  For example, if we had two graphs such as:

    graph_a: x0 -> conv_0 (type: nn.Conv2d) -> obs_0 -> x1
             w  -/
             b  -/

    graph_b: x0 -> quant_0 -> qconv_0 (type: nnq.Conv2d) -> dequant_0 -> x1
           packed_params_0 -/

    This function will return a result of the following shape:
    {
        'base_op_<conv base name>_0': (  # generated name, see below
          (conv_0, conv_0),  # (start_node_a, end_node_a)
          (qconv_0, qconv_0),  # (start_node_b, end_node_b)
        ),
    }

    Or, if we have a fusion pattern,

    graph_a: x0 -> linear_0 -> relu_0 -> obs_0 -> x1
             w  -/
             b  -/

    graph_b: x0 -> quant_0 -> linear_relu_0 -> dequant_0 -> x1
           packed_params_0 -/

    This function will return a result of the following shape:
    {
        'base_op_<linear base name>_0': (  # generated name, see below
          (linear_0, relu_0),  # (start_node_a, end_node_a)
          (linear_relu_0, linear_relu_0),  # (start_node_b, end_node_b)
        ),
    }

    Keys are not node names. They have the form
    ``base_op_<base name>_<counter>``, where ``<base name>`` is the key in
    ``base_name_to_sets_of_related_ops`` whose set contains the base op's type.
    The same key is computed independently for graph_a and graph_b.

    Args:
        gm_a: first model.
        gm_b: second model.
        base_name_to_sets_of_related_ops: related-op sets keyed by base name;
            defaults to ``get_base_name_to_sets_of_related_ops()``.
        unmatchable_types_map: dict with the keys "funs_unmatchable",
            "mods_unmatchable" and "meths_unmatchable"; defaults to
            ``get_unmatchable_types_map()``.

    Returns:
        An OrderedDict mapping generated names to ``(subgraph_a, subgraph_b)``.
        The order is the reverse of the output-to-input traversal, which is
        intended to be execution order.

    Raises:
        GraphMatchingException: if a pair of subgraphs is not related, or one
            graph runs out of matchable subgraphs before the other.
        AssertionError: if the generated names for a pair differ.
    """
    if unmatchable_types_map is None:
        unmatchable_types_map = get_unmatchable_types_map()
    non_matchable_functions = unmatchable_types_map["funs_unmatchable"]
    non_matchable_modules = unmatchable_types_map["mods_unmatchable"]
    non_matchable_methods = unmatchable_types_map["meths_unmatchable"]

    graph_a_iterator = _NSGraphMatchableSubgraphsIterator(
        gm_a, non_matchable_functions, non_matchable_modules, non_matchable_methods
    )
    graph_b_iterator = _NSGraphMatchableSubgraphsIterator(
        gm_b, non_matchable_functions, non_matchable_modules, non_matchable_methods
    )
    results = collections.OrderedDict()
    if base_name_to_sets_of_related_ops is None:
        base_name_to_sets_of_related_ops = get_base_name_to_sets_of_related_ops()
    type_a_related_to_b = get_type_a_related_to_b(base_name_to_sets_of_related_ops)

    existing_names_a: Set[str] = set()
    existing_names_b: Set[str] = set()

    while True:
        # fetch the next subgraphs from a and b; None means that graph is exhausted
        cur_subgraph_a = next(graph_a_iterator, None)
        cur_subgraph_b = next(graph_b_iterator, None)

        # look up types of a and b for useful error messages
        type_start_a, type_start_b = None, None
        if cur_subgraph_a is not None:
            type_start_a = _get_node_target_type(cur_subgraph_a.start_node, gm_a)
        if cur_subgraph_b is not None:
            type_start_b = _get_node_target_type(cur_subgraph_b.start_node, gm_b)

        if cur_subgraph_a is None and cur_subgraph_b is None:
            # we reached the end of both graphs
            break

        if cur_subgraph_a is None or cur_subgraph_b is None:
            # only one subgraph was fetched, no match possible, throw error
            msg = f"""
Attempting to match
({cur_subgraph_a}, {type_start_a}) and
({cur_subgraph_b}, {type_start_b}),
one of which is empty. Please ensure that the two models you pass in have the same number
of subgraphs."""
            raise GraphMatchingException(msg)

        # both subgraphs were fetched, check their relationship
        # note: the relationship is checked on the base op node, i.e.
        # if a linear-relu pattern is checked, we would check the linear
        subgraph_relationship = _get_subgraph_relationship_type(
            cur_subgraph_a, cur_subgraph_b, gm_a, gm_b, type_a_related_to_b
        )
        if subgraph_relationship == SubgraphTypeRelationship.NOT_RELATED:
            msg = f"""
The subgraphs
({cur_subgraph_a}, {type_start_a}) and
({cur_subgraph_b}, {type_start_b})
are not related. Please ensure that the two models you pass in have the same number
of subgraphs, and each pair of subgraphs is related to each other."""
            raise GraphMatchingException(msg)
        if subgraph_relationship == SubgraphTypeRelationship.EQUAL_BUT_UKNOWN:
            # skip matching but unknown types
            continue

        key_name_a = _get_name_for_subgraph(
            cur_subgraph_a, gm_a, base_name_to_sets_of_related_ops, existing_names_a
        )
        key_name_b = _get_name_for_subgraph(
            cur_subgraph_b, gm_b, base_name_to_sets_of_related_ops, existing_names_b
        )
        assert (
            key_name_a == key_name_b
        ), f"Subgraph names {key_name_a} and {key_name_b} do not match"
        results[key_name_a] = (cur_subgraph_a, cur_subgraph_b)

    # The subgraph pairs are originally created by traversing the two graphs
    # from the outputs to the inputs. Reverse the results to return the
    # subgraphs in their order of execution.
    results = collections.OrderedDict(reversed(results.items()))

    return results
470    return results
471