1# mypy: allow-untyped-defs 2import sys 3from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Type 4 5import torch 6from torch.ao.quantization.qconfig import QConfigAny 7from torch.ao.quantization.utils import MatchAllNode, Pattern 8from torch.fx.graph import Graph, Node 9from torch.nn.utils.parametrize import type_before_parametrizations 10 11from .graph_module import _is_observed_standalone_module 12from .quantize_handler import QuantizeHandler 13 14 15__all__: List[str] = [] 16 17# TODO(future PR): the 1st argument is typed as `List[Node]`, but a better type 18# would be a recursive `List[Union[Node, Tuple[Union[Node, ...]]]]` 19_MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler] 20 21_MatchResultWithQConfig = Tuple[ 22 Node, List[Node], Optional[Pattern], QuantizeHandler, QConfigAny 23] 24 25 26# Note: The order of patterns is important! match function will take whatever is matched first, so we'll 27# need to put the fusion patterns before single patterns. For example, add_relu should be registered come before relu. 28# decorators are applied in the reverse order we see. Also when we match the nodes in the graph with these patterns, 29# we'll start from the last node of the graph and traverse back. 
def _is_match(modules, node, pattern, max_uses=sys.maxsize):
    """Return whether the fx ``node`` matches ``pattern``.

    ``pattern`` is either a single matcher or a tuple
    ``(self_match, *arg_matches)``. In the tuple form, ``self_match`` is
    matched against ``node`` itself. Each entry of ``arg_matches`` is matched
    recursively against the corresponding entry of ``node.args``. The
    two-element form ``(getattr, attr_name)`` is special: it matches a
    ``getattr(x, attr_name)`` call, and ``attr_name`` is compared with the
    node's attribute-name argument rather than treated as a sub-pattern.

    A single matcher may be:
      * a subclass of ``MatchAllNode``: matches anything;
      * a ``torch.nn.Module`` subclass: matches a ``call_module`` node whose
        module (looked up in ``modules``, ignoring parametrizations) is
        exactly that type;
      * any other callable: matches a ``call_function`` node targeting it;
      * a string: matches a ``call_method`` node with that method name;
      * anything else: compared for equality with ``node.target``.

    Args:
        modules: mapping from fully qualified module name to module instance,
            used to resolve ``call_module`` targets.
        node: the candidate. When recursing into arguments this may be a
            non-Node value (e.g. a constant). Such a value only matches
            ``MatchAllNode`` or a pattern equal to it.
        pattern: the pattern to match against.
        max_uses: ``node`` does not match if it has more than this many users.

    Returns:
        ``True`` if ``node`` matches ``pattern``, otherwise ``False``.
    """
    if isinstance(pattern, tuple):
        self_match, *arg_matches = pattern
        if self_match is getattr:
            assert len(pattern) == 2, "Expecting getattr pattern to have two elements"
            # The second element of a getattr pattern is the attribute name,
            # not a sub-pattern for the node's arguments; it is checked below.
            arg_matches = []
    else:
        self_match = pattern
        arg_matches = []

    if isinstance(self_match, type) and issubclass(self_match, MatchAllNode):
        return True

    # Lets non-Node arguments (e.g. constants) match a literal equal pattern.
    if node == pattern:
        return True

    if not isinstance(node, Node) or len(node.users) > max_uses:
        return False

    if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
        if node.op != "call_module":
            return False
        if type_before_parametrizations(modules[node.target]) != self_match:
            return False
    elif callable(self_match):
        if node.op != "call_function" or node.target is not self_match:
            return False
        # `getattr` is itself callable, so its attribute-name check must live
        # inside this branch. As a sibling `elif` it could never run for getattr
        # patterns, and it would instead misfire on non-callable (e.g. string)
        # patterns. A bare `getattr` pattern (not a tuple) has no name to check.
        if (
            self_match is getattr
            and isinstance(pattern, tuple)
            and node.args[1] != pattern[1]
        ):
            return False
    elif isinstance(self_match, str):
        if node.op != "call_method" or node.target != self_match:
            return False
    elif node.target != self_match:
        return False

    if not arg_matches:
        return True

    if len(arg_matches) != len(node.args):
        return False

    # Each argument is matched with max_uses=1: an argument Node that has
    # more than one user does not match.
    return all(
        _is_match(modules, arg, arg_match, max_uses=1)
        for arg, arg_match in zip(node.args, arg_matches)
    )


def _find_matches(
    graph: Graph,
    modules: Dict[str, torch.nn.Module],
    patterns: Dict[Pattern, QuantizeHandler],
    root_node_getter_mapping: Dict[Pattern, Callable],
    standalone_module_names: Optional[List[str]] = None,
    standalone_module_classes: Optional[List[Type]] = None,
    custom_module_classes: Optional[List[Any]] = None,
) -> Dict[str, _MatchResult]:
    """
    Matches the nodes in the input graph to quantization patterns, and
    outputs the information needed to quantize them in future steps.

    Inputs:
      - graph: an fx.Graph object
      - modules: a mapping of fully qualified module name to instance,
        for example, {'foo': ModuleFoo, ...}
      - patterns: a mapping from a pattern to an uninitialized
        QuantizeHandler subclass. A pattern is a nested tuple in reverse
        order: its first element matches the node itself, and the remaining
        elements match that node's arguments (see ``_is_match``). Patterns
        are tried in iteration order, and the first one that matches a node
        wins.
      - root_node_getter_mapping: per-pattern callable passed as the third
        positional argument to that pattern's QuantizeHandler constructor
        (``None`` is passed when the pattern has no entry).
      - standalone_module_names: fully qualified names of modules whose call
        nodes are recorded as standalone-module matches.
      - standalone_module_classes: module types whose call nodes are recorded
        as standalone-module matches.
      - custom_module_classes: module types whose call nodes are recorded as
        custom-module matches.

    Nodes are visited from the last node of the graph backwards. A node
    already covered by an earlier match is not matched again.

    Outputs a map of
      node_name ->
        (matched_at_node, matched_node_pattern, matched_pattern,
         QuantizeHandler instance)

    Every node appearing in ``matched_node_pattern`` maps to the same tuple.
    Custom-module and standalone-module call nodes are recorded as
    ``(node, node, None, QuantizeHandler(...))``. These overwrite any pattern
    match for that node.

    For example, {
      'relu_1': (relu_1, [relu_1], torch.nn.functional.relu,
                 <CopyNodeQuantizeHandler instance>),
      ...
    }
    """
    if custom_module_classes is None:
        custom_module_classes = []

    if standalone_module_classes is None:
        standalone_module_classes = []

    if standalone_module_names is None:
        standalone_module_names = []

    match_map: Dict[str, _MatchResult] = {}

    def _record_nodes(node_pattern, match_value) -> None:
        """Store ``match_value`` in ``match_map`` for every Node nested in
        ``node_pattern``; non-Node leaves (constants, strings) are skipped."""
        if isinstance(node_pattern, Node):
            match_map[node_pattern.name] = match_value
        elif isinstance(node_pattern, Iterable) and not isinstance(
            node_pattern, (str, bytes)
        ):
            # Strings must be treated as leaves: iterating a one-character
            # string yields that same string, which would recurse forever.
            for sub_pattern in node_pattern:
                _record_nodes(sub_pattern, match_value)

    def _record_match(pattern, node, matched_node_pattern) -> None:
        """Append to ``matched_node_pattern`` the values that ``pattern``
        matched at ``node``, mirroring the nesting of ``pattern``."""
        if not isinstance(pattern, tuple):
            matched_node_pattern.append(node)
            return

        self_pattern, *arg_patterns = pattern
        _record_match(self_pattern, node, matched_node_pattern)
        if self_pattern is getattr:
            # The second element of a getattr pattern is an attribute name,
            # not a sub-pattern for the node's arguments.
            return

        arg_node_pattern: List[Node] = []
        for subpattern, arg in zip(arg_patterns, node.args):
            _record_match(subpattern, arg, arg_node_pattern)

        if len(arg_node_pattern) > 1:
            # Recover the original structure of the pattern. With a single
            # argument sub-pattern we get (original_op, (original_arg, ...)).
            # Otherwise we get a flat list (original_op, arg0, arg1, ...).
            if len(arg_patterns) == 1:
                matched_node_pattern.append(tuple(arg_node_pattern))
            else:
                matched_node_pattern.extend(arg_node_pattern)
        elif arg_node_pattern:
            matched_node_pattern.append(arg_node_pattern[0])
        # An empty arg_node_pattern (pattern without argument sub-patterns)
        # contributes nothing beyond the node itself.

    for node in reversed(graph.nodes):
        if node.name in match_map:
            continue
        for pattern, quantize_handler_cls in patterns.items():
            if not _is_match(modules, node, pattern):
                continue
            matched_node_pattern: List[Node] = []
            _record_match(pattern, node, matched_node_pattern)
            quantize_handler = quantize_handler_cls(  # type: ignore[operator]
                matched_node_pattern,
                modules,
                root_node_getter_mapping.get(pattern, None),
            )
            # Record the match for every node in the matched pattern.
            _record_nodes(
                matched_node_pattern,
                (node, matched_node_pattern, pattern, quantize_handler),
            )
            break

    # add custom module instances to the match result
    assert modules is not None
    for node in graph.nodes:
        if (
            node.op == "call_module"
            and type(modules[node.target]) in custom_module_classes
        ):
            match_map[node.name] = (
                node,
                node,
                None,
                QuantizeHandler(node, modules, is_custom_module=True),
            )

    def is_standalone_module(node_target: str) -> bool:
        """True if the target is listed by name or its module type is listed."""
        return (
            node_target in standalone_module_names
            or type(modules[node_target])  # type: ignore[operator]
            in standalone_module_classes  # type: ignore[operator]
        )

    # add standalone modules to the match
    for node in graph.nodes:
        if node.op == "call_module" and (
            is_standalone_module(node.target)
            or _is_observed_standalone_module(modules[node.target])
        ):
            match_map[node.name] = (
                node,
                node,
                None,
                QuantizeHandler(node, modules, is_standalone_module=True),
            )

    return match_map