xref: /aosp_15_r20/external/pytorch/torch/fx/subgraph_rewriter.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from .graph_module import GraphModule
2from .graph import Graph
3from .node import Node
4from ._symbolic_trace import symbolic_trace
5from ._compatibility import compatibility
6
7import copy
8from dataclasses import dataclass
9from typing import Any, Callable, Dict, List, NamedTuple, Optional, Set, Union, TYPE_CHECKING
10import torch
11
12if TYPE_CHECKING:
13    from .passes.utils.matcher_with_name_node_map_utils import InternalMatch
14
15__all__ = ['Match', 'replace_pattern', 'replace_pattern_with_filters', "ReplacedPatterns"]
16
@compatibility(is_backward_compatible=True)
class Match(NamedTuple):
    """
    One location in a graph where a pattern subgraph was matched.

    ``replace_pattern`` returns one ``Match`` per replaced occurrence.
    """
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]
23
@compatibility(is_backward_compatible=False)
@dataclass
class ReplacedPatterns:
    """
    One replaced occurrence of a pattern, including what was inserted.

    Carries the same ``anchor`` / ``nodes_map`` information as ``Match`` plus
    the nodes that the rewrite put into the graph. Instances are produced by
    ``_replace_pattern`` and returned by ``replace_pattern_with_filters``.
    """
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]
    # List of nodes that were added into the graph (the copies of the
    # replacement graph's nodes, excluding the match's placeholder inputs)
    replacements: List[Node]
33
def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None:
    """
    Make sure every attribute referenced by ``gm``'s graph exists on ``gm``,
    copying missing ones over from ``replacement``.

    Unused submodules of ``gm`` are deleted first. Then, for each
    ``call_module`` / ``get_attr`` node in ``gm.graph``:

    * if the target already resolves on ``gm``, it is left alone;
    * otherwise, if it resolves on ``replacement``, a deep copy is installed
      on ``gm``;
    * otherwise a ``RuntimeError`` is raised.

    A target whose value is ``None`` is treated the same as a missing one.

    Args:
        gm: The rewritten GraphModule to repair in place.
        replacement: The module whose attributes the replacement subgraph
            refers to. If it is a GraphModule its graph is linted first.

    Raises:
        RuntimeError: If a referenced target exists on neither ``gm`` nor
            ``replacement``.
    """
    gm.delete_all_unused_submodules()

    if isinstance(replacement, GraphModule):
        replacement.graph.lint()

    def try_get_attr(gm: torch.nn.Module, target: str) -> Optional[Any]:
        # Resolve a dotted target: everything before the last "." names the
        # owning submodule, the remainder names the attribute on it.
        module_path, _, attr_name = target.rpartition(".")
        try:
            mod: torch.nn.Module = gm.get_submodule(module_path)
        except AttributeError:
            return None
        return getattr(mod, attr_name, None)

    for node in gm.graph.nodes:
        if node.op not in ("call_module", "get_attr"):
            continue

        gm_attr = try_get_attr(gm, node.target)
        replacement_attr = try_get_attr(replacement, node.target)

        # CASE 1: This target already exists as an attribute in our
        # result GraphModule. Whether or not it exists in
        # `replacement`, the existing submodule takes precedence.
        if gm_attr is not None:
            continue

        # CASE 3: The target doesn't exist as an attribute in `gm`
        # or `replacement`
        if replacement_attr is None:
            # Build ONE message string. Passing the pieces as separate
            # arguments would make the exception carry (and print) a tuple.
            raise RuntimeError(
                f'Attempted to create a "{node.op}" node during subgraph '
                f"rewriting with target {node.target}, but "
                "the referenced attribute does not "
                "exist in the replacement GraphModule"
            )

        # CASE 2: The target exists as an attribute in `replacement`
        # only, so we need to copy it over.
        new_attr = copy.deepcopy(replacement_attr)
        if isinstance(replacement_attr, torch.nn.Module):
            gm.add_submodule(node.target, new_attr)
        else:
            # NOTE(review): for a dotted target this calls setattr on `gm`
            # with the full dotted name instead of descending into the
            # nested submodule -- confirm whether nested non-module
            # attributes need dedicated handling.
            setattr(gm, node.target, new_attr)

    gm.graph.lint()
80
81
@compatibility(is_backward_compatible=True)
def replace_pattern(
    gm: GraphModule,
    pattern: Union[Callable, GraphModule],
    replacement: Union[Callable, GraphModule]
) -> List[Match]:
    """
    Find every non-overlapping occurrence of ``pattern`` (a set of operators
    together with their data dependencies) in the Graph of ``gm`` and swap
    each occurrence for the ``replacement`` subgraph.

    Args:
        ``gm``: The GraphModule that wraps the Graph to operate on
        ``pattern``: The subgraph to match in ``gm`` for replacement
        ``replacement``: The subgraph to replace ``pattern`` with

    Returns:
        List[Match]: One ``Match`` per place in the original graph where
        ``pattern`` was matched; empty if nothing matched. ``Match`` is
        defined as:

        .. code-block:: python

            class Match(NamedTuple):
                # Node from which the match was found
                anchor: Node
                # Maps nodes in the pattern subgraph to nodes in the larger graph
                nodes_map: Dict[Node, Node]

    Examples:

    .. code-block:: python

        import torch
        from torch.fx import symbolic_trace, subgraph_rewriter

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, w1, w2):
                m1 = torch.cat([w1, w2]).sum()
                m2 = torch.cat([w1, w2]).sum()
                return x + torch.max(m1) + torch.max(m2)

        def pattern(w1, w2):
            return torch.cat([w1, w2]).sum()

        def replacement(w1, w2):
            return torch.stack([w1, w2])

        traced_module = symbolic_trace(M())

        subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

    **How matching works.** ``pattern`` is looked up in the ``forward``
    of ``traced_module`` by use-def relationships, never by node names:
    ``p = torch.cat([a, b])`` in ``pattern`` matches
    ``m = torch.cat([a, b])`` in ``forward`` even though one result is
    called ``p`` and the other ``m``.

    The ``return`` of ``pattern`` is matched on its value only. It may or
    may not line up with the ``return`` of the larger graph, so a pattern
    does not need to reach the end of the graph.

    **What gets replaced.** Each matched occurrence is removed from the
    larger function and ``replacement`` is put in its place. When several
    occurrences exist, every non-overlapping one is replaced. When
    occurrences overlap, the first one found in that overlapping set is
    replaced, "first" meaning first in a topological ordering of the
    Nodes' use-def relationships. (In most cases the first Node is the
    parameter directly after ``self`` and the last Node is whatever the
    function returns.)

    **Rules for parameters.**

    1. Every parameter of the ``pattern`` Callable must be used inside
       it. That is why, above, ``forward`` takes ``x, w1, w2`` while
       ``pattern`` takes only ``w1, w2``: ``pattern`` never touches
       ``x``, so it must not declare it.
    2. The parameters of the ``replacement`` Callable must match those of
       ``pattern``. For example, when replacing

       .. code-block:: python

           def pattern(x, y):
               return torch.neg(x) + torch.relu(y)

       with

       .. code-block:: python

           def replacement(x, y):
               return torch.relu(x)

       ``replacement`` still needs the same number of parameters as
       ``pattern`` (both ``x`` and ``y``) although it never uses ``y``.

    **Result for the example above.** Both ``torch.cat(...).sum()``
    chains are matched and erased, and a ``torch.stack`` call takes the
    place of each, so the regenerated code looks roughly like this (node
    names and statement order are illustrative and may differ):

    .. code-block:: python

        def forward(self, x, w1, w2):
            stack = torch.stack([w1, w2])
            max_1 = torch.max(stack)
            add = x + max_1
            stack_1 = torch.stack([w1, w2])
            max_2 = torch.max(stack_1)
            add_1 = add + max_2
            return add_1
    """
    # The shared implementation reports richer records; expose only the
    # backward-compatible ``Match`` view of each one.
    return [
        Match(anchor=replaced.anchor, nodes_map=replaced.nodes_map)
        for replaced in _replace_pattern(gm, pattern, replacement)
    ]
203
204
205# Experimental API, not backward compatible
@compatibility(is_backward_compatible=False)
def replace_pattern_with_filters(
    gm: GraphModule,
    pattern: Union[Callable, Graph, GraphModule],
    replacement: Union[Callable, Graph, GraphModule],
    match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None,
    ignore_literals: bool = False,
) -> List[ReplacedPatterns]:
    """
    Variant of ``replace_pattern`` that lets callers veto individual matches.
    See ``replace_pattern`` for the general matching and replacement rules.

    Args:
        ``gm``: The GraphModule that wraps the Graph to operate on
        ``pattern``: The subgraph to match in ``gm`` for replacement
        ``replacement``: The subgraph to replace ``pattern`` with
        ``match_filters``: A list of functions that take in
            (match: InternalMatch, original_graph: Graph, pattern_graph: Graph) and return a boolean indicating
            whether the match satisfies the condition. A match is replaced
            only if every filter returns a truthy value.
            See matcher_utils.py for definition of InternalMatch.
        ``ignore_literals``: Forwarded unchanged to the ``SubgraphMatcher``
            that performs the matching.

    Returns:
        List[ReplacedPatterns]: One record per replaced match.
    """
    return _replace_pattern(
        gm,
        pattern,
        replacement,
        match_filters=match_filters,
        ignore_literals=ignore_literals,
    )
225
226
def _replace_pattern(
    gm: GraphModule,
    pattern: Union[Callable, Graph, GraphModule],
    replacement: Union[Callable, Graph, GraphModule],
    match_filters: Optional[List[Callable[["InternalMatch", Graph, Graph], bool]]] = None,
    ignore_literals: bool = False,
) -> List[ReplacedPatterns]:
    """
    Shared implementation behind ``replace_pattern`` and
    ``replace_pattern_with_filters``.

    Matches ``pattern`` against ``gm.graph``, keeps the matches accepted by
    every filter, and for each kept match copies ``replacement`` into the
    graph, reroutes the users of the matched outputs to the copy, and erases
    the matched nodes. ``gm`` is recompiled at the end; if ``replacement`` is
    an ``nn.Module`` its referenced attributes are copied over as well.

    Args:
        gm: GraphModule whose graph is rewritten in place.
        pattern: Callable (symbolically traced here), Graph, or GraphModule
            describing what to look for.
        replacement: Callable (symbolically traced here), Graph, or
            GraphModule describing what to insert.
        match_filters: Optional predicates called as
            ``f(match, original_graph, pattern_graph)``; a match is kept only
            if all of them return a truthy value.
        ignore_literals: Passed through to ``SubgraphMatcher``.

    Returns:
        One ``ReplacedPatterns`` per replaced match, in matcher order.
    """
    from torch.fx.passes.utils.matcher_utils import SubgraphMatcher, InternalMatch

    def as_graph(source: Union[Callable, Graph, GraphModule]) -> Graph:
        # A GraphModule carries its graph, a Graph is used directly, and
        # anything else is assumed traceable.
        if isinstance(source, GraphModule):
            return source.graph
        if isinstance(source, Graph):
            return source
        return symbolic_trace(source).graph

    filters = [] if match_filters is None else match_filters

    # The three graphs involved: the one being edited, the one to look for,
    # and the one to insert.
    original_graph: Graph = gm.graph
    pattern_graph = as_graph(pattern)
    replacement_graph = as_graph(replacement)

    matcher = SubgraphMatcher(
        pattern_graph,
        match_output=False,
        match_placeholder=False,
        remove_overlapping_matches=True,
        ignore_literals=ignore_literals,
    )
    candidates: List[InternalMatch] = matcher.match(original_graph)

    # Drop every candidate that at least one filter rejects.
    accepted = [
        candidate for candidate in candidates
        if all(accept(candidate, original_graph, pattern_graph) for accept in filters)
    ]

    replacement_placeholders = [
        n for n in replacement_graph.nodes if n.op == "placeholder"
    ]

    # Earlier replacements may have swapped out nodes that later matches
    # still refer to; this maps each replaced node to its successor.
    match_changed_node: Dict[Node, Node] = {}

    results = []
    for match in accepted:

        # Seed `val_map` so that each placeholder of the replacement graph
        # resolves to the node (or literal) feeding the match in the
        # original graph.
        assert len(match.placeholder_nodes) == len(replacement_placeholders)
        val_map: Dict[Node, Node] = {}
        for rn, gn in zip(replacement_placeholders, match.placeholder_nodes):
            if not isinstance(gn, Node):
                val_map[rn] = gn
                continue

            current = match_changed_node.get(gn, gn)
            val_map[rn] = current
            if current != gn:
                # `gn` was replaced by an earlier match: patch this match's
                # bookkeeping so it points at the node that took its place.
                stale_index = match.placeholder_nodes.index(gn)
                match.placeholder_nodes[stale_index] = current
                pattern_keys = list(match.nodes_map.keys())
                graph_values = list(match.nodes_map.values())
                match.nodes_map[pattern_keys[graph_values.index(gn)]] = current

        # Everything that consumes an output of the matched subgraph.
        user_nodes: Set[Node] = {
            user for returning in match.returning_nodes for user in returning.users
        }
        assert user_nodes, "The returning_nodes should have at least one user node"

        # The copy has to be inserted ahead of the earliest consumer in the
        # graph's current node order.
        if len(user_nodes) == 1:
            (first_user_node,) = user_nodes
        else:
            for candidate_node in original_graph.nodes:
                if candidate_node in user_nodes:
                    first_user_node = candidate_node
                    break

        with original_graph.inserting_before(first_user_node):  # type: ignore[possibly-undefined]
            copied = original_graph.graph_copy(replacement_graph, val_map)

        # Normalise a single returned Node to a 1-tuple.
        copied_returning_nodes = (copied, ) if isinstance(copied, Node) else copied

        # `graph_copy` extended `val_map` with the freshly created nodes;
        # whatever is not one of the match inputs is a replacement node.
        replacement_nodes: List[Node] = [
            v for v in val_map.values() if v not in match.placeholder_nodes
        ]

        # Point every consumer of a matched output at the corresponding
        # copied output, and remember the substitution for later matches.
        assert len(match.returning_nodes) == len(copied_returning_nodes)  # type: ignore[arg-type]
        for old_output, new_output in zip(match.returning_nodes, copied_returning_nodes):  # type: ignore[arg-type]
            old_output.replace_all_uses_with(new_output)
            match_changed_node[old_output] = new_output

        # Erase the matched nodes, walking the pattern backwards so users
        # are removed before the nodes they depend on.
        for pattern_node in reversed(pattern_graph.nodes):
            if pattern_node.op in ("placeholder", "output"):
                continue
            gm.graph.erase_node(match.nodes_map[pattern_node])

        results.append(
            ReplacedPatterns(
                anchor=match.anchors[0],
                nodes_map=match.nodes_map,
                replacements=replacement_nodes,
            )
        )

    # Regenerate `gm`'s code from the edited graph.
    gm.recompile()

    # A module-based replacement may reference submodules/attributes that
    # `gm` does not have yet; bring them over.
    if isinstance(replacement, torch.nn.Module):
        _replace_attributes(gm, replacement)

    return results
349