from .graph_module import GraphModule
from .graph import Graph
from .node import Node
from ._symbolic_trace import symbolic_trace
from ._compatibility import compatibility

import copy
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union, TYPE_CHECKING
import torch

if TYPE_CHECKING:
    from .passes.utils.matcher_with_name_node_map_utils import InternalMatch

__all__ = ['Match', 'replace_pattern', 'replace_pattern_with_filters', "ReplacedPatterns"]


@compatibility(is_backward_compatible=True)
class Match(NamedTuple):
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]


@compatibility(is_backward_compatible=False)
@dataclass
class ReplacedPatterns:
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]
    # List of nodes that were added into the graph
    replacements: List[Node]


def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None:
    """Make every attribute referenced by ``gm.graph`` available on ``gm``.

    After a rewrite, ``gm.graph`` may contain ``call_module`` / ``get_attr``
    nodes copied from ``replacement``'s graph whose targets only exist on
    ``replacement``. Those attributes are deep-copied onto ``gm``; an
    attribute that already exists on ``gm`` always takes precedence.

    Args:
        gm: The rewritten GraphModule; modified in place.
        replacement: The module the replacement subgraph was taken from.

    Raises:
        RuntimeError: If a node references an attribute that exists on
            neither ``gm`` nor ``replacement``.
    """
    # Drop submodules the rewritten graph no longer references.
    gm.delete_all_unused_submodules()

    if isinstance(replacement, GraphModule):
        replacement.graph.lint()

    def try_get_attr(gm: torch.nn.Module, target: str) -> Optional[Any]:
        # Resolve a (possibly dotted) target; None if any component is missing.
        module_path, _, attr_name = target.rpartition(".")
        try:
            mod: torch.nn.Module = gm.get_submodule(module_path)
        except AttributeError:
            return None
        return getattr(mod, attr_name, None)

    for node in gm.graph.nodes:
        if node.op not in ("call_module", "get_attr"):
            continue

        gm_attr = try_get_attr(gm, node.target)
        replacement_attr = try_get_attr(replacement, node.target)

        # CASE 1: This target already exists as an attribute in our
        # result GraphModule. Whether or not it exists in
        # `replacement`, the existing submodule takes precedence.
        if gm_attr is not None:
            continue

        # CASE 2: The target exists as an attribute in `replacement`
        # only, so we need to copy it over.
        if replacement_attr is not None:
            new_attr = copy.deepcopy(replacement_attr)
            if isinstance(replacement_attr, torch.nn.Module):
                gm.add_submodule(node.target, new_attr)
            else:
                # NOTE(review): for a dotted non-module target (e.g.
                # "sub.weight") this sets an attribute on `gm` whose name
                # literally contains the dot — confirm this is intended.
                setattr(gm, node.target, new_attr)
            continue

        # CASE 3: The target doesn't exist as an attribute in `gm`
        # or `replacement`
        raise RuntimeError(
            f'Attempted to create a "{node.op}" node during subgraph rewriting '
            f"with target {node.target}, but the referenced attribute does not "
            "exist in the replacement GraphModule"
        )

    gm.graph.lint()


@compatibility(is_backward_compatible=True)
def replace_pattern(
    gm: GraphModule,
    pattern: Union[Callable, GraphModule],
    replacement: Union[Callable, GraphModule]
) -> List[Match]:
    """
    Matches all possible non-overlapping sets of operators and their
    data dependencies (``pattern``) in the Graph of a GraphModule
    (``gm``), then replaces each of these matched subgraphs with another
    subgraph (``replacement``). ``gm`` is modified in place and
    recompiled.

    Args:
        ``gm``: The GraphModule that wraps the Graph to operate on
        ``pattern``: The subgraph to match in ``gm`` for replacement
        ``replacement``: The subgraph to replace ``pattern`` with

    Returns:
        List[Match]: A list of ``Match`` objects representing the places
        in the original graph that ``pattern`` was matched to. The list
        is empty if there are no matches. ``Match`` is defined as:

        .. code-block:: python

            class Match(NamedTuple):
                # Node from which the match was found
                anchor: Node
                # Maps nodes in the pattern subgraph to nodes in the larger graph
                nodes_map: Dict[Node, Node]

    Examples:

    .. code-block:: python

        import torch
        from torch.fx import symbolic_trace, subgraph_rewriter

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, w1, w2):
                m1 = torch.cat([w1, w2]).sum()
                m2 = torch.cat([w1, w2]).sum()
                return x + torch.max(m1) + torch.max(m2)

        def pattern(w1, w2):
            return torch.cat([w1, w2])

        def replacement(w1, w2):
            return torch.stack([w1, w2])

        traced_module = symbolic_trace(M())

        subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

    The above code will first match ``pattern`` in the ``forward``
    method of ``traced_module``. Pattern-matching is done based on
    use-def relationships, not node names. For example, if you had
    ``p = torch.cat([a, b])`` in ``pattern``, you could match
    ``m = torch.cat([a, b])`` in the original ``forward`` function,
    despite the variable names being different (``p`` vs ``m``).

    The ``return`` statement in ``pattern`` is matched based on its
    value only; it may or may not match to the ``return`` statement in
    the larger graph. In other words, the pattern doesn't have to extend
    to the end of the larger graph.

    When the pattern is matched, it will be removed from the larger
    function and replaced by ``replacement``. If there are multiple
    matches for ``pattern`` in the larger function, each non-overlapping
    match will be replaced. In the case of a match overlap, the first
    found match in the set of overlapping matches will be replaced.
    ("First" here being defined as the first in a topological ordering
    of the Nodes' use-def relationships. In most cases, the first Node
    is the parameter that appears directly after ``self``, while the
    last Node is whatever the function returns.)

    One important thing to note is that the parameters of the
    ``pattern`` Callable must be used in the Callable itself,
    and the parameters of the ``replacement`` Callable must match
    the pattern. The first rule is why, in the above code block, the
    ``forward`` function has parameters ``x, w1, w2``, but the
    ``pattern`` function only has parameters ``w1, w2``. ``pattern``
    doesn't use ``x``, so it shouldn't specify ``x`` as a parameter.
    As an example of the second rule, consider replacing

    .. code-block:: python

        def pattern(x, y):
            return torch.neg(x) + torch.relu(y)

    with

    .. code-block:: python

        def replacement(x, y):
            return torch.relu(x)

    In this case, ``replacement`` needs the same number of parameters
    as ``pattern`` (both ``x`` and ``y``), even though the parameter
    ``y`` isn't used in ``replacement``.

    After calling ``subgraph_rewriter.replace_pattern``, the generated
    Python code looks like this:

    .. code-block:: python

        def forward(self, x, w1, w2):
            stack_1 = torch.stack([w1, w2])
            sum_1 = stack_1.sum()
            stack_2 = torch.stack([w1, w2])
            sum_2 = stack_2.sum()
            max_1 = torch.max(sum_1)
            add_1 = x + max_1
            max_2 = torch.max(sum_2)
            add_2 = add_1 + max_2
            return add_2
    """
    match_and_replacements = _replace_pattern(gm, pattern, replacement)
    return [Match(anchor=m.anchor, nodes_map=m.nodes_map) for m in match_and_replacements]


# Experimental API, not backward compatible
@compatibility(is_backward_compatible=False)
def replace_pattern_with_filters(
    gm: GraphModule,
    pattern: Union[Callable, Graph, GraphModule],
    replacement: Union[Callable, Graph, GraphModule],
    match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None,
    ignore_literals: bool = False,
) -> List[ReplacedPatterns]:
    """
    See ``replace_pattern`` for documentation. This is a variant of
    ``replace_pattern`` that additionally accepts ``match_filters`` and
    ``ignore_literals``, also accepts a bare ``Graph`` for ``pattern`` and
    ``replacement``, and returns ``ReplacedPatterns`` (which also records
    the nodes inserted for each match) instead of ``Match``.

    Args:
        ``match_filters``: A list of functions that take in
            (match: InternalMatch, original_graph: Graph, pattern_graph: Graph) and return a boolean indicating
            whether the match satisfies the condition. A match is replaced
            only if every filter returns ``True``.
            See matcher_utils.py for definition of InternalMatch.
        ``ignore_literals``: Forwarded to ``SubgraphMatcher``. When ``True``,
            literal arguments in ``pattern`` are treated as wildcards instead
            of being compared for equality.

    Returns:
        List[ReplacedPatterns]: One entry per replaced match, holding the
        anchor, the pattern-to-graph node map, and the newly inserted nodes.
    """

    return _replace_pattern(gm, pattern, replacement, match_filters, ignore_literals)


def _as_graph(obj: Union[Callable, Graph, GraphModule]) -> Graph:
    """Return the FX ``Graph`` for ``obj``, symbolically tracing callables."""
    if isinstance(obj, GraphModule):
        return obj.graph
    if isinstance(obj, Graph):
        return obj
    return symbolic_trace(obj).graph


def _replace_pattern(
    gm: GraphModule,
    pattern: Union[Callable, Graph, GraphModule],
    replacement: Union[Callable, Graph, GraphModule],
    match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None,
    ignore_literals: bool = False,
) -> List[ReplacedPatterns]:
    """Shared implementation of ``replace_pattern`` / ``replace_pattern_with_filters``.

    Finds non-overlapping matches of ``pattern`` in ``gm.graph``, discards
    matches rejected by any of ``match_filters``, splices a copy of
    ``replacement`` in place of each remaining match, recompiles ``gm`` and,
    if ``replacement`` is an ``nn.Module``, copies over attributes it owns.

    Returns:
        One ``ReplacedPatterns`` per replaced match, in match order.
    """
    # Imported lazily, as in the original module layout.
    from torch.fx.passes.utils.matcher_utils import SubgraphMatcher, InternalMatch

    if match_filters is None:
        match_filters = []

    # Get the graphs for `gm`, `pattern`, `replacement`
    original_graph: Graph = gm.graph
    pattern_graph = _as_graph(pattern)
    replacement_graph = _as_graph(replacement)

    matcher = SubgraphMatcher(pattern_graph, match_output=False, match_placeholder=False,
                              remove_overlapping_matches=True, ignore_literals=ignore_literals)
    _matches: List[InternalMatch] = matcher.match(original_graph)

    # Keep only matches accepted by every filter
    _matches = [
        m for m in _matches
        if all(match_filter(m, original_graph, pattern_graph)
               for match_filter in match_filters)
    ]

    replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"]

    # Matches are all computed up front, so once an earlier match's returning
    # node is replaced, later matches that consume it as an input must be
    # redirected to the node that replaced it.
    match_changed_node: Dict[Node, Node] = {}

    match_and_replacements = []
    for match in _matches:
        # Connect the replacement graph's inputs to the nodes in
        # `original_graph` that feed this match.
        assert len(match.placeholder_nodes) == len(replacement_placeholders), (
            f"The replacement takes {len(replacement_placeholders)} input(s) but the "
            f"pattern match has {len(match.placeholder_nodes)}; the replacement must "
            "have the same number of parameters as the pattern"
        )
        # Map each replacement placeholder to its corresponding input in
        # `original_graph` (seed for `graph_copy`).
        val_map: Dict[Node, Node] = {}
        # Iterate over a snapshot since placeholder_nodes is updated in place.
        for idx, (rn, gn) in enumerate(zip(replacement_placeholders, list(match.placeholder_nodes))):
            if not isinstance(gn, Node):
                # Non-Node inputs (e.g. literals) are passed through unchanged.
                val_map[rn] = gn
                continue
            new_gn = match_changed_node.get(gn, gn)
            val_map[rn] = new_gn
            if gn != new_gn:
                # gn was erased by an earlier replacement; point this match's
                # bookkeeping at the node that replaced it.
                match.placeholder_nodes[idx] = new_gn
                map_key = list(match.nodes_map.keys())[list(match.nodes_map.values()).index(gn)]
                match.nodes_map[map_key] = new_gn

        # Insert the copy right before the earliest consumer of the match's
        # outputs so that every consumer sees the new nodes already defined.
        user_nodes: Set[Node] = set()
        for n in match.returning_nodes:
            user_nodes.update(n.users)
        assert user_nodes, "The returning_nodes should have at least one user node"

        if len(user_nodes) == 1:
            first_user_node = next(iter(user_nodes))
        else:
            # Earliest user in the current execution order of `original_graph`
            first_user_node = next(n for n in original_graph.nodes if n in user_nodes)

        # Copy the replacement graph over; `graph_copy` also fills `val_map`
        # with the mapping for every copied node.
        with original_graph.inserting_before(first_user_node):
            copied_returning_nodes = original_graph.graph_copy(replacement_graph, val_map)

        if isinstance(copied_returning_nodes, Node):
            copied_returning_nodes = (copied_returning_nodes, )

        # Nodes newly added to the graph: everything `graph_copy` mapped,
        # minus the pre-existing inputs.
        replacement_nodes: List[Node] = [v for v in val_map.values() if v not in match.placeholder_nodes]

        # Hook the output Node(s) of the replacement subgraph in to the
        # original Graph at the correct location
        assert len(match.returning_nodes) == len(copied_returning_nodes)  # type: ignore[arg-type]
        for gn, copied_node in zip(match.returning_nodes, copied_returning_nodes):  # type: ignore[arg-type]
            gn.replace_all_uses_with(copied_node)
            match_changed_node[gn] = copied_node

        # Remove the original matched nodes; reverse pattern order erases
        # users before their producers.
        for node in reversed(pattern_graph.nodes):
            if node.op != "placeholder" and node.op != "output":
                original_graph.erase_node(match.nodes_map[node])

        match_and_replacements.append(
            ReplacedPatterns(
                anchor=match.anchors[0],
                nodes_map=match.nodes_map,
                replacements=replacement_nodes
            )
        )

    # Update the passed-in GraphModule to reflect the new state of
    # `original_graph`
    gm.recompile()

    # If `replacement` was an nn.Module, we'll need to make sure that
    # all the submodules have been copied over correctly
    if isinstance(replacement, torch.nn.Module):
        _replace_attributes(gm, replacement)

    return match_and_replacements