xref: /aosp_15_r20/external/pytorch/torch/ao/quantization/fx/match_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import sys
3from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type
4
5import torch
6from torch.ao.quantization.qconfig import QConfigAny
7from torch.ao.quantization.utils import MatchAllNode, Pattern
8from torch.fx.graph import Graph, Node
9from torch.nn.utils.parametrize import type_before_parametrizations
10
11from .graph_module import _is_observed_standalone_module
12from .quantize_handler import QuantizeHandler
13
14
# `__all__` is empty, so `from ... import *` exports nothing from this module.
__all__: List[str] = []

# Value type of the map returned by `_find_matches`. The elements are, in order:
#   - the last node of the match (the node at which the pattern matched),
#   - the matched node pattern collected while walking the pattern,
#   - the pattern that matched (None for custom and standalone modules),
#   - the QuantizeHandler instance created for the match.
# TODO(future PR): the 2nd element is typed as `List[Node]`, but a better type
# would be a recursive `List[Union[Node, Tuple[Union[Node, ...]]]]`
_MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler]

# Same layout as `_MatchResult`, with a trailing `QConfigAny` element.
_MatchResultWithQConfig = Tuple[
    Node, List[Node], Optional[Pattern], QuantizeHandler, QConfigAny
]
24
25
# Note: The order of patterns is important! The match function takes whatever is matched first, so
# fusion patterns must be placed before single-op patterns. For example, add_relu should be registered
# before relu. Decorators are applied in the reverse of the order in which they appear. Also, when we match
# the nodes in the graph against these patterns, we start from the last node of the graph and traverse backwards.
def _is_match(modules, node, pattern, max_uses=sys.maxsize):
    """Check whether ``node`` (and, recursively, its args) matches ``pattern``.

    Args:
        modules: mapping from module name to module instance; used to look up
            the module behind a ``call_module`` node.
        node: the candidate to test. Usually an fx ``Node``, but any value
            found in ``Node.args`` may be passed during recursion.
        pattern: either a single matcher, or a tuple
            ``(matcher, *arg_patterns)`` whose trailing elements are matched
            positionally against ``node.args``. The matcher may be:

            * a ``MatchAllNode`` subclass - matches anything;
            * a ``torch.nn.Module`` subclass - matches a ``call_module`` node
              whose module type (before parametrizations) equals that class;
            * any other callable - matches a ``call_function`` node whose
              target is that very object;
            * a ``str`` - matches a ``call_method`` node with that target;
            * anything else - compared against ``node.target``.

            A ``(getattr, name)`` tuple is special: ``name`` is compared with
            ``node.args[1]`` instead of being treated as an arg pattern.
        max_uses: largest number of users ``node`` may have and still match.
            Recursive calls on args use ``max_uses=1``.

    Returns:
        ``True`` if the node matches the pattern, ``False`` otherwise.
    """
    # Split the pattern into the matcher for this node and the arg patterns.
    head, operand_patterns = pattern, []
    if isinstance(pattern, tuple):
        head, *operand_patterns = pattern
        if head is getattr:
            assert len(pattern) == 2, "Expecting getattr pattern to have two elements"
            # The attribute name is checked below, not matched as an arg.
            operand_patterns = []

    head_is_class = isinstance(head, type)

    # Wildcard: accepts any value, regardless of kind or number of users.
    if head_is_class and issubclass(head, MatchAllNode):
        return True

    # Literal equality covers constants and exact node references.
    if node == pattern:
        return True

    if not (isinstance(node, Node) and len(node.users) <= max_uses):
        return False

    if head_is_class and issubclass(head, torch.nn.Module):
        if node.op != "call_module":
            return False
        if type_before_parametrizations(modules[node.target]) != head:
            return False
    elif callable(head):
        if not (node.op == "call_function" and node.target is head):
            return False
        # Here node.target is head, so testing head is the same as testing the target.
        if head is getattr and node.args[1] != pattern[1]:
            return False
    elif isinstance(head, str):
        if not (node.op == "call_method" and node.target == head):
            return False
    elif node.target != head:
        return False

    if not operand_patterns:
        return True

    actual_args = node.args
    if len(actual_args) != len(operand_patterns):
        return False

    # Every argument must match its sub-pattern and feed only this node.
    for actual_arg, operand_pattern in zip(actual_args, operand_patterns):
        if not _is_match(modules, actual_arg, operand_pattern, max_uses=1):
            return False
    return True
77
78
def _find_matches(
    graph: Graph,
    modules: Dict[str, torch.nn.Module],
    patterns: Dict[Pattern, QuantizeHandler],
    root_node_getter_mapping: Dict[Pattern, Callable],
    standalone_module_names: Optional[List[str]] = None,
    standalone_module_classes: Optional[List[Type]] = None,
    custom_module_classes: Optional[List[Any]] = None,
) -> Dict[str, _MatchResult]:
    """
    Matches the nodes in the input graph to quantization patterns, and
    outputs the information needed to quantize them in future steps.

    Inputs:
      - graph: an fx.Graph object
      - modules: a mapping of fully qualified module name to instance,
          for example, {'foo': ModuleFoo, ...}
      - patterns: a mapping from a tuple of nodes in reverse order to
          uninitialized QuantizeHandler subclass. Patterns are tried in the
          mapping's iteration order and the first one that matches wins.
      - root_node_getter_mapping: a mapping from pattern to a callable. The
          entry for the matched pattern (or None when there is none) is
          passed as the third argument to the handler constructor.
      - standalone_module_names: names of `call_module` targets to record
          as standalone modules (defaults to no names)
      - standalone_module_classes: module types to record as standalone
          modules (defaults to no classes)
      - custom_module_classes: module types to record as custom modules
          (defaults to no classes)

    The graph is walked from the last node to the first, and a node that is
    already part of an earlier match is not matched again. Custom module and
    standalone module entries are written after pattern matching and replace
    any pattern-based entry for the same node.

    Outputs a map of
      node_name ->
        (node, matched_values, matched_pattern, QuantizeHandler instance)

    For example, {
      'relu_1': (relu_1, [relu_1], torch.nn.functional.relu,
                 <CopyNodeQuantizeHandler instance>),
      ...
    }

    For custom and standalone modules, matched_values is the node itself and
    matched_pattern is None.
    """
    if custom_module_classes is None:
        custom_module_classes = []

    if standalone_module_classes is None:
        standalone_module_classes = []

    if standalone_module_names is None:
        standalone_module_names = []

    match_map: Dict[str, _MatchResult] = {}

    def _recursive_record_node_in_match_map(
        last_node, match_map, node_pattern, matched_node_pattern, pattern, match_value
    ):
        """Record every Node found in the (possibly nested) node_pattern."""
        if isinstance(node_pattern, Node):
            match_map[node_pattern.name] = (
                last_node,
                matched_node_pattern,
                pattern,
                match_value,
            )
        elif isinstance(node_pattern, (str, bytes)):
            # Strings are iterable but can never contain a Node; descending
            # into them would recurse forever because a one-character string
            # yields itself.
            return
        elif not isinstance(node_pattern, Iterable):
            return
        else:
            for n in node_pattern:
                _recursive_record_node_in_match_map(
                    last_node, match_map, n, matched_node_pattern, pattern, match_value
                )

    # TODO: 1. merge with fuse matcher 2. document the code
    def record_match(pattern, node, matched_node_pattern):
        """Append the values matched by `pattern` at `node` to matched_node_pattern."""
        if not isinstance(pattern, tuple):
            matched_node_pattern.append(node)
            return

        s, *args = pattern
        is_single_arg = len(args) == 1
        current_node_pattern: List[Node] = []
        record_match(s, node, matched_node_pattern)
        # The second element of a getattr pattern is an attribute name, not a
        # sub-pattern, so there are no args to descend into.
        if s is not getattr:
            for subpattern, arg in zip(args, node.args):
                record_match(subpattern, arg, current_node_pattern)
        if len(current_node_pattern) > 1:
            # current_node_pattern is the node pattern we get from matching
            # the subpattern with arguments of the node
            # we use is_single_arg to recover the original structure of the pattern
            # if the original pattern has a single argument, we will have
            # (original_op, (original_arg, ...))
            # otherwise, we'll have a list of arguments
            # (original_op, arg0, arg1, arg2, ...)
            if is_single_arg:
                matched_node_pattern.append(tuple(current_node_pattern))
            else:
                matched_node_pattern.extend(current_node_pattern)
        elif current_node_pattern:
            matched_node_pattern.append(current_node_pattern[0])
        # else: nothing was collected from the args (getattr pattern or a
        # one-element tuple pattern); the op itself was already recorded above.

    for node in reversed(graph.nodes):
        if node.name in match_map:
            # already claimed by a match found at a later node
            continue
        for pattern, quantize_handler_cls in patterns.items():
            if not _is_match(modules, node, pattern):
                continue
            matched_node_pattern: List[Node] = []
            record_match(pattern, node, matched_node_pattern)
            quantize_handler = quantize_handler_cls(  # type: ignore[operator]
                matched_node_pattern, modules, root_node_getter_mapping.get(pattern)
            )
            # record the match for all nodes in the pattern
            _recursive_record_node_in_match_map(
                node,
                match_map,
                # we need to record all nodes in the matched pattern in the match_map
                matched_node_pattern,
                # this is a part of the value corresponding to the node
                matched_node_pattern,
                pattern,
                quantize_handler,
            )
            break

    # add custom module instances to the match result
    assert modules is not None
    for node in graph.nodes:
        if (
            node.op == "call_module"
            and type(modules[node.target]) in custom_module_classes
        ):
            match_map[node.name] = (
                node,
                node,
                None,
                QuantizeHandler(node, modules, is_custom_module=True),
            )

    def is_standalone_module(node_target: str, modules: Dict[str, torch.nn.Module]):
        """Return True if the target is listed by name or by module type."""
        return (
            node_target in standalone_module_names
            or type(modules[node_target])  # type: ignore[operator]
            in standalone_module_classes  # type: ignore[operator]
        )

    # add standalone modules to the match
    for node in graph.nodes:
        if node.op == "call_module" and (
            is_standalone_module(node.target, modules)
            or _is_observed_standalone_module(modules[node.target])
        ):
            # add node to matched nodes
            match_map[node.name] = (
                node,
                node,
                None,
                QuantizeHandler(node, modules, is_standalone_module=True),
            )

    return match_map
228