# mypy: allow-untyped-defs

import contextlib
import copy
import operator
from typing import Callable

import torch
from torch._ops import HigherOrderOperator

from ..utils import node_replace_, nodes_map


def _replace_with_hop_helper(
    node: torch.fx.Node,
    enter_block_node: torch.fx.Node,
    node_filter: Callable,
    wrap_hoo: HigherOrderOperator,
):
    graph: torch.fx.Graph = node.graph
    gm: torch.fx.GraphModule = graph.owning_module
    assert isinstance(node.target, str)
    sub_gm = getattr(gm, node.target)

    def set_hoo_node_meta(call_func_node):
        call_func_node.meta["nn_module_stack"] = copy.copy(
            enter_block_node.meta.get("nn_module_stack", {})
        )
        call_func_node.meta["torch_fn"] = (
            f"{wrap_hoo.__name__}",
            f"{wrap_hoo.__class__.__name__}.{wrap_hoo.__name__}",
        )
        if isinstance(output_args, (tuple, list)):
            call_func_node.meta["val"] = tuple(arg.meta["val"] for arg in output_args)
        elif isinstance(output_args, torch.fx.Node):
            call_func_node.meta["val"] = (output_args.meta["val"],)

    with graph.inserting_before(node):
        get_attr_node = graph.get_attr(node.target)
        get_attr_node.meta["nn_module_stack"] = copy.copy(
            enter_block_node.meta.get("nn_module_stack", {})
        )
        output_node = next(iter(reversed(sub_gm.graph.nodes)), None)
        # The split_module pass intentionally doesn't add an output node
        # if the graph doesn't return anything.
        # TODO (tmanlaibaatar) Figure out if this is the right behaviour
        # for split_module
        if isinstance(output_node, torch.fx.Node) and output_node.op != "output":
            output_node = None
        if output_node is not None:
            assert len(output_node.args) == 1
            output_args = output_node.args[0]
            enter_block_node_args = enter_block_node.args
            if isinstance(output_args, (tuple, list)):
                call_func_node = graph.call_function(
                    wrap_hoo,
                    (*enter_block_node_args, get_attr_node, *node.args),
                    {},
                )
                # Create the metadata
                set_hoo_node_meta(call_func_node)
                node_replace_(node, call_func_node)

                # Rename each getitem node to the name of the subgraph output
                # it extracts, to pass the verifier and improve readability;
                # also propagate metadata.
                for get_item_node in call_func_node.users.keys():
                    idx: int = get_item_node.args[1]  # type: ignore[assignment]
                    output_node = output_args[idx]
                    get_item_node._rename(output_node.name)
                    get_item_node.meta = output_node.meta

            elif isinstance(output_args, torch.fx.Node):
                call_func_node = graph.create_node(
                    "call_function",
                    wrap_hoo,
                    (*enter_block_node_args, get_attr_node, *node.args),
                    {},
                    output_args.name,
                )
                # Modify the subgraph to output a singleton list.
                output_node.args = ((output_args,),)
                # Add in an extra `getitem(wrap_hoo, 0)` node to the toplevel graph.
                get_item_node = graph.create_node(
                    "call_function",
                    operator.getitem,
                    (call_func_node, 0),
                    {},
                )
                # Create the metadata
                get_item_node.meta = output_args.meta
                set_hoo_node_meta(call_func_node)
                node_replace_(node, get_item_node)
            else:
                raise NotImplementedError(
                    f"replace_with_hop_pass doesn't support output type {type(output_args)}"
                )
        else:
            # TODO (shangdiy): remove this line, since the export graph can be non-functional
            node.graph.erase_node(node)


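# A minimal usage sketch (not part of this module): drive
# `_replace_with_hop_helper` over the `call_module` nodes that a graph
# splitter produced. The choice of `enter_block_node` below (the first node
# of the submodule) and the always-true `node_filter` are assumptions for
# illustration; real callers pass the node that opened the block (e.g. a
# `torch._C._set_grad_enabled` call) and a matching HigherOrderOperator.
def _example_replace_all_submodules(
    gm: torch.fx.GraphModule, wrap_hoo: HigherOrderOperator
) -> None:
    for node in list(gm.graph.nodes):
        if node.op != "call_module":
            continue
        assert isinstance(node.target, str)
        sub_gm = getattr(gm, node.target)
        # Stand-in "enter block" marker; see the caveat above.
        enter_block_node = next(iter(sub_gm.graph.nodes))
        _replace_with_hop_helper(node, enter_block_node, lambda n: True, wrap_hoo)
    gm.recompile()

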
def _sequential_split_and_maybe_inline_subgraphs_helper(
    new_gm: torch.fx.GraphModule,
    graph_signature,
    maybe_inline_or_replace_with_hop: Callable[[torch.fx.Node], None],
):
    """
    Helper function for replacing graph nodes with higher-order-op nodes.
    For each subgraph in `new_gm`, decides whether to construct a HOO subgraph or inline the calls
    back into the parent graph module, depending on `maybe_inline_or_replace_with_hop`.
    """
    # new_gm is a new graph module that could have different output args names.
    # We need to fix the graph signature.
    replace_ctx = contextlib.nullcontext()
    new_signature = None
    if graph_signature is not None:
        # Cannot deep-copy a real ScriptObject, which is referenced
        # in the FakeScriptObject. A shallow copy should be good enough
        # to guard against accidental mutation of the original graph_signature.
        new_signature = copy.copy(graph_signature)
        new_gm_out_node = next(reversed(new_gm.graph.find_nodes(op="output")))
        assert new_gm_out_node.op == "output" and len(new_gm_out_node.args[0]) == len(
            new_signature.output_specs
        )
        for arg_node, out_spec in zip(
            new_gm_out_node.args[0], new_signature.output_specs
        ):
            if arg_node is None:
                assert out_spec.arg.value is None
            elif (
                isinstance(arg_node, torch.fx.Node)
                and out_spec.arg.name != arg_node.name
            ):
                out_spec.arg.name = arg_node.name

        replace_ctx = new_gm._set_replace_hook(new_signature.get_replace_hook())  # type: ignore[assignment]

    with replace_ctx:
        nodes_map(
            list(new_gm.graph.nodes),
            lambda node: (
                maybe_inline_or_replace_with_hop(node)
                if node.op == "call_module"
                else node
            ),
        )
    new_gm.recompile()
    return new_gm, new_signature


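# A minimal sketch (an assumption, not code from any real pass) of a
# `maybe_inline_or_replace_with_hop` callback: wrap every split submodule in
# `wrap_hoo`, reusing the submodule's first node as a stand-in for the
# "enter block" marker. Real callbacks also inline back the submodules that
# don't need wrapping.
def _example_make_replace_with_hop(
    wrap_hoo: HigherOrderOperator,
) -> Callable[[torch.fx.Node], None]:
    def _replace(node: torch.fx.Node) -> None:
        assert isinstance(node.target, str)
        sub_gm = getattr(node.graph.owning_module, node.target)
        enter_block_node = next(iter(sub_gm.graph.nodes))
        _replace_with_hop_helper(node, enter_block_node, lambda n: True, wrap_hoo)

    return _replace

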
def _replace_with_hop_pass_helper(
    gm: torch.fx.GraphModule,
    graph_signature,
    sequential_split_and_maybe_inline_subgraphs: Callable,
):
    """
    Split `gm` into sub-graph-modules using `sequential_split_and_maybe_inline_subgraphs`,
    then recursively apply the same transformation to each of the submodules.
    """
    new_gm, new_signature = sequential_split_and_maybe_inline_subgraphs(
        gm, graph_signature
    )
    # Recursively apply the pass to nested graph modules.
    for node in new_gm.graph.nodes:
        if node.op == "get_attr":
            subgm = getattr(new_gm, node.target)
            if not isinstance(subgm, torch.fx.GraphModule):
                continue
            new_subgm, _ = _replace_with_hop_pass_helper(
                subgm,
                None,
                sequential_split_and_maybe_inline_subgraphs,
            )
            setattr(new_gm, node.target, new_subgm)

    new_gm.recompile()
    new_gm.graph.lint()
    return new_gm, new_signature

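
# A minimal sketch, assuming a trivial splitter, of how the pieces fit
# together: a real pass would first split `gm` into submodules (e.g. with
# `sequential_split` from ..utils) and pass a real callback such as the one
# sketched above. Here the split step is elided, so the helper just walks
# the unsplit graph and visits any existing call_module nodes.
def _example_pass(gm: torch.fx.GraphModule, graph_signature):
    def _split_and_maybe_inline(gm, graph_signature):
        # Placeholder: no actual splitting; only fix up the output specs of
        # the signature and map over call_module nodes (possibly none).
        return _sequential_split_and_maybe_inline_subgraphs_helper(
            gm, graph_signature, lambda node: node
        )

    return _replace_with_hop_pass_helper(gm, graph_signature, _split_and_maybe_inline)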