xref: /aosp_15_r20/external/pytorch/torch/_export/passes/replace_set_grad_with_hop_pass.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2
3import torch
4from torch._higher_order_ops.wrap import wrap_with_set_grad_enabled
5
6from ..utils import node_inline_, nodes_filter, nodes_first, nodes_map, sequential_split
7from .replace_with_hop_pass_util import (
8    _replace_with_hop_helper,
9    _replace_with_hop_pass_helper,
10    _sequential_split_and_maybe_inline_subgraphs_helper,
11)
12
13
def _is_set_grad_enabled_node(node: torch.fx.Node):
    """Tell whether ``node`` is a call to ``torch._C._set_grad_enabled``.

    A falsy ``node`` (for example ``None``) is handed back unchanged, so the
    result is falsy in that case. Otherwise the node must be a
    ``call_function`` node whose target compares equal to
    ``torch._C._set_grad_enabled``.
    """
    if not node:
        # Preserve the falsy input itself as the (falsy) answer.
        return node
    if node.op != "call_function":
        return False
    return node.target == torch._C._set_grad_enabled
20
21
def _is_set_grad_enabled_sub_mod(node: torch.fx.Node, omit_if_same_with_ambient=False):
    """Tell whether ``node`` calls a submodule that starts with a set_grad call.

    ``node`` qualifies when it is a ``call_module`` node and the first
    non-placeholder node of the referenced submodule's graph is a
    ``call_function`` on ``torch._C._set_grad_enabled``.

    When ``omit_if_same_with_ambient`` is true, such a submodule only counts
    if the value passed to ``_set_grad_enabled`` differs from the current
    ``torch.is_grad_enabled()`` state. Anything else yields ``False``.
    """
    if node.op != "call_module":
        return False

    assert isinstance(node.target, str)
    submodule = getattr(node.graph.owning_module, node.target)
    head = nodes_first(submodule.graph.nodes, lambda n: n.op != "placeholder")

    starts_with_set_grad = (
        head
        and head.op == "call_function"
        and head.target == torch._C._set_grad_enabled
    )
    if not starts_with_set_grad:
        return False
    if not omit_if_same_with_ambient:
        return True
    # Only report the submodule when it would actually flip the grad mode.
    return head.args[0] != torch.is_grad_enabled()
40
41
def _replace_with_hop(node: torch.fx.Node):
    """Rewrite a set_grad submodule call through ``wrap_with_set_grad_enabled``.

    ``node`` must be a ``call_module`` node with a string target. The
    referenced submodule's graph is searched for set_grad_enabled nodes; if
    none are found nothing happens. Otherwise exactly one is expected: the
    rewrite is delegated to ``_replace_with_hop_helper`` and that node is then
    erased from the submodule's graph.
    """
    assert node.op == "call_module"
    assert isinstance(node.target, str)

    parent_gm: torch.fx.GraphModule = node.graph.owning_module
    inner_graph = getattr(parent_gm, node.target).graph

    matches = nodes_filter(inner_graph.nodes, _is_set_grad_enabled_node)
    if len(matches) == 0:
        return

    assert len(matches) == 1
    set_grad_node = matches[0]
    _replace_with_hop_helper(
        node,
        set_grad_node,
        _is_set_grad_enabled_node,
        wrap_with_set_grad_enabled,
    )
    # The set_grad call is removed from the submodule once the helper is done.
    inner_graph.erase_node(set_grad_node)
57
58
def _remove_set_grad_and_inline(node: torch.fx.Node):
    """Strip set_grad_enabled calls from a submodule and inline its call.

    ``node`` must be a ``call_module`` node with a string target. Every
    set_grad_enabled node in the referenced submodule's graph is erased, and
    then ``node_inline_`` is applied to ``node``.
    """
    assert node.op == "call_module"
    assert isinstance(node.target, str)

    inner_graph = getattr(node.graph.owning_module, node.target).graph

    def _drop_if_set_grad(candidate):
        # Erase set_grad calls; pass every other node through untouched.
        if _is_set_grad_enabled_node(candidate):
            return inner_graph.erase_node(candidate)
        return candidate

    nodes_map(inner_graph.nodes, _drop_if_set_grad)
    node_inline_(node)
71
72
def _sequential_split_and_maybe_inline_subgraphs(
    gm: torch.fx.GraphModule, graph_signature
):
    """
    Helper function for replace_set_grad_with_hop_pass().

    If the graph contains no set_grad_enabled node, ``gm`` and
    ``graph_signature`` are returned untouched. Otherwise the graph module is
    split with ``sequential_split`` at the set_grad_enabled nodes, and each
    resulting submodule call is either rewritten into a higher-order-op call
    or stripped of its set_grad calls and inlined back into the parent.
    """
    if not any(map(_is_set_grad_enabled_node, gm.graph.nodes)):
        return gm, graph_signature

    def _handle_submodule_call(node: torch.fx.Node):
        # Submodules that change the ambient grad mode become a HOP call;
        # everything else is inlined after dropping its set_grad calls.
        if _is_set_grad_enabled_sub_mod(node, omit_if_same_with_ambient=True):
            _replace_with_hop(node)
            return
        _remove_set_grad_and_inline(node)

    # sequential_split returns a new graph module that could have different
    # output arg names, so the graph signature needs fixing afterwards; both
    # are handed to the shared helper below.
    split_gm = sequential_split(gm, _is_set_grad_enabled_node)
    return _sequential_split_and_maybe_inline_subgraphs_helper(
        split_gm, graph_signature, _handle_submodule_call
    )
99
100
def replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule, graph_signature):
    """
    Replace set_grad_enabled regions of ``gm`` with higher-order-op calls.

    Delegates to ``_replace_with_hop_pass_helper``, passing
    ``_sequential_split_and_maybe_inline_subgraphs`` as the routine that
    splits ``gm`` into sub-graph-modules; the helper is expected to apply the
    pass to the submodules as well. Returns whatever the helper returns.
    """
    split_fn = _sequential_split_and_maybe_inline_subgraphs
    return _replace_with_hop_pass_helper(gm, graph_signature, split_fn)
111