xref: /aosp_15_r20/external/pytorch/torch/onnx/_internal/fx/passes/_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2"""Common utility functions for FX passes.
3
4These functions should NOT be directly invoked outside of `passes` package.
5"""
6
7from __future__ import annotations
8
9import collections
10import re
11from typing import Callable
12
13import torch.fx
14import torch.fx.traceback as fx_traceback
15
16
def wrap_graph_module_for_node_meta_preservation(
    graph_module: torch.fx.GraphModule,
) -> Callable:
    """Produce a callable that runs ``graph_module`` with node meta preserved.

    The returned callable interprets the graph under
    ``fx_traceback.preserve_node_meta()``, so node meta information (such as
    stacktraces) survives subsequent tracing. This is typically useful before
    calling `make_fx`; without the wrapper that information would be lost.
    """

    def run_with_preserved_meta(*args):
        with fx_traceback.preserve_node_meta():
            interpreter = torch.fx.Interpreter(graph_module)
            return interpreter.run(*args)

    return run_with_preserved_meta
31
32
33def _get_node_base_name(node_name: str) -> tuple[str, int | None]:
34    pattern = r"(.*)\.(\d+)"
35    match = re.match(pattern, node_name)
36    if match is not None:
37        base_name, count_str = match.groups()
38        return base_name, int(count_str)
39    return node_name, None
40
41
def set_node_name(
    node: torch.fx.Node,
    new_name: str,
    name_to_node_cache: dict[str, torch.fx.Node],
):
    """Safely set the unique name of a node.

    If the new name is already taken by another node, the name of the other node will be
    updated. If `new_name` is a string of format f"{base_name}.{count}", where `count`
    is an integer, the other node will be renamed as f"{base_name}.{count+1}". If not,
    the other node will be renamed as "{new_name}.1". This function will iteratively
    update the names until there is no conflict.

    ``name_to_node_cache`` is required as an argument to avoid recomputation. The caller
    is responsible for ensuring the cache is accurate and in sync with the owning module
    of the node. The values in the cache will be updated accordingly.

    Args:
        node: The node to update.
        new_name: The new name to use.
        name_to_node_cache: A cache of node names to nodes.
    """
    # Work list of pending (node, name) renames. Resolving one conflict can
    # displace another node's name, so keep processing until no conflicts
    # remain. (The previously assigned-but-unused `module` local was removed.)
    node_name_to_set = collections.deque([(node, new_name)])

    while node_name_to_set:
        node, new_name = node_name_to_set.pop()
        if new_name in name_to_node_cache and name_to_node_cache[new_name] != node:
            # Another node currently owns `new_name`: bump its numeric
            # postfix (starting from 0 when there is none) and queue it
            # for renaming before we claim the name below.
            base_name, postfix_count = _get_node_base_name(new_name)
            if postfix_count is None:
                postfix_count = 0
            node_name_to_set.append(
                (name_to_node_cache[new_name], f"{base_name}.{postfix_count + 1}")
            )
        node.name = new_name
        name_to_node_cache[new_name] = node
78
79
def replace_placeholder_name_and_target(
    module: torch.fx.GraphModule, reference_module: torch.fx.GraphModule
):
    """Replace the argument names in module with those in reference_module.

    This function assumes the two modules have the same signature structure.
    The caller is responsible for ensuring this. Otherwise, the behavior of this
    function is undefined. This function only does minimal sanity check that the two
    modules have the same number of arguments.

    Name conflicts between new names and existing node names in the graph are handled.
    Check the documentation of :func:`set_node_name` for more details.

    Raises:
        RuntimeError: If the two modules have different number of arguments.
    """

    def _placeholders(graph_module: torch.fx.GraphModule) -> list[torch.fx.Node]:
        # Placeholder nodes are the graph's input arguments.
        return [n for n in graph_module.graph.nodes if n.op == "placeholder"]

    placeholders = _placeholders(module)
    reference_placeholders = _placeholders(reference_module)

    if len(placeholders) != len(reference_placeholders):
        raise RuntimeError(
            "The two modules have different number of arguments. "
            f"module: {len(placeholders)}, reference_module: {len(reference_placeholders)}"
        )

    # Cache every existing node name so set_node_name can detect conflicts.
    name_to_node: dict[str, torch.fx.Node] = {
        node.name: node for node in module.graph.nodes
    }

    for placeholder, reference_placeholder in zip(placeholders, reference_placeholders):
        placeholder.target = reference_placeholder.target
        set_node_name(placeholder, reference_placeholder.name, name_to_node)

    # Regenerate the module's forward() code from the mutated graph.
    module.recompile()
116