xref: /aosp_15_r20/external/pytorch/torch/_export/passes/replace_autocast_with_hop_pass.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import List
3
4import torch
5from torch._higher_order_ops.wrap import wrap_with_autocast
6
7from ..utils import node_inline_, nodes_filter, nodes_first, sequential_split
8from .replace_with_hop_pass_util import (
9    _replace_with_hop_helper,
10    _replace_with_hop_pass_helper,
11    _sequential_split_and_maybe_inline_subgraphs_helper,
12)
13
14
15def _is_autocast_node(node: torch.fx.Node):
16    return (
17        node
18        and node.op == "call_function"
19        and node.target
20        in [
21            torch.amp.autocast_mode._enter_autocast,
22            torch.amp.autocast_mode._exit_autocast,
23        ]
24    )
25
26
27def _is_enter_autocast_node(node: torch.fx.Node):
28    return (
29        node
30        and node.op == "call_function"
31        and node.target == torch.amp.autocast_mode._enter_autocast
32    )
33
34
35def _is_exit_autocast_node(node: torch.fx.Node):
36    return (
37        node
38        and node.op == "call_function"
39        and node.target == torch.amp.autocast_mode._exit_autocast
40    )
41
42
def _is_autocast_sub_mod(node: torch.fx.Node):
    """
    Return True if ``node`` is a ``call_module`` node whose submodule's first
    non-placeholder node is a call to ``torch.amp.autocast_mode._enter_autocast``.
    """
    if node.op != "call_module":
        return False
    assert isinstance(node.target, str)
    submodule = getattr(node.graph.owning_module, node.target)
    first_real_node = nodes_first(
        submodule.graph.nodes, lambda n: n.op != "placeholder"
    )
    if not first_real_node or first_real_node.op != "call_function":
        return False
    # TODO: check if current auto-cast type is the same as the args of
    # _enter_autocast. If so, return False, i.e. do not create a submodule.
    return first_real_node.target == torch.amp.autocast_mode._enter_autocast
62
63
def _check_valid_autocast_block(enter_autocast_node, exit_autocast_node):
    """
    Sanity-check that the two nodes form a matched autocast region:
    ``enter_autocast_node`` must be an ``_enter_autocast`` call,
    ``exit_autocast_node`` an ``_exit_autocast`` call, and the exit must
    reference the enter node as its first argument (which is how the
    matching pair is linked in the FX graph).
    """
    assert _is_enter_autocast_node(enter_autocast_node)
    assert _is_exit_autocast_node(exit_autocast_node)
    # _exit_autocast takes the corresponding _enter_autocast node as arg 0.
    assert exit_autocast_node.args[0] == enter_autocast_node
68
69
def _replace_with_hop(node: torch.fx.Node):
    """
    Rewrite the ``call_module`` node's submodule so that its autocast region
    becomes a ``wrap_with_autocast`` higher-order op, then drop the raw
    ``_enter_autocast`` / ``_exit_autocast`` nodes from the subgraph.
    No-op when the submodule contains no autocast nodes.
    """
    assert node.op == "call_module"
    graph: torch.fx.Graph = node.graph
    owning_gm: torch.fx.GraphModule = graph.owning_module
    assert isinstance(node.target, str)
    sub_gm = getattr(owning_gm, node.target)
    sub_graph = sub_gm.graph

    autocast_nodes = nodes_filter(sub_graph.nodes, _is_autocast_node)
    if not autocast_nodes:
        return
    # A region needs at least an enter node and an exit node.
    assert len(autocast_nodes) > 1
    enter_node = autocast_nodes[0]
    exit_node = autocast_nodes[-1]
    _check_valid_autocast_block(enter_node, exit_node)

    _replace_with_hop_helper(node, enter_node, _is_autocast_node, wrap_with_autocast)
    # Erase the exit node first: it still holds the enter node in its args.
    sub_graph.erase_node(exit_node)
    sub_graph.erase_node(enter_node)
89
90
def _split_autocast(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    Return a new graph module whose nodes are partitioned into submodules at
    the boundaries of the **outer-most** autocast regions. ``gm`` itself is
    not mutated.

    Every node from an outer-most ``_enter_autocast`` up to and including its
    matching ``_exit_autocast`` lands in a single submodule; nested autocast
    regions stay inside the enclosing submodule rather than being split out.

    Example (A..E are runs of non-autocast nodes; the number is the submodule
    each node is grouped into):
    A               # 0
    enter_autocast  # 1
    B               # 1
    exit_autocast   # 1
    C               # 2
    enter_autocast  # 3
    D               # 3
    exit_autocast   # 3
    E               # 4
    """
    open_enters: List[torch.fx.Node] = []
    begin_next_partition: bool = False

    def _starts_new_partition(node: torch.fx.Node):
        # sequential_split opens a fresh submodule whenever this returns True.
        nonlocal begin_next_partition
        if begin_next_partition or (
            not open_enters and _is_enter_autocast_node(node)
        ):
            assert not open_enters
            begin_next_partition = False
            if _is_enter_autocast_node(node):
                open_enters.append(node)
            return True
        if _is_exit_autocast_node(node):
            assert open_enters
            matching_enter = open_enters.pop()
            assert node.args[0] == matching_enter
            if not open_enters:
                # The outer-most region just closed; whatever node comes next
                # must start a brand-new submodule.
                begin_next_partition = True
        return False

    return sequential_split(gm, _starts_new_partition)
136
137
def _sequential_split_and_maybe_inline_subgraphs(
    gm: torch.fx.GraphModule, graph_signature
):
    """
    Helper for ``replace_autocast_with_hop_pass``.

    Splits ``gm`` into submodules at autocast boundaries (the nodes between an
    ``_enter_autocast`` and its matching ``_exit_autocast`` form one
    subgraph). For each submodule it then either builds a HOO subgraph
    (autocast submodules) or inlines the calls back into the parent graph
    (everything else), returning the resulting graph module and signature.
    """
    if not any(_is_autocast_node(node) for node in gm.graph.nodes):
        # No autocast calls anywhere — nothing to rewrite.
        return gm, graph_signature

    # _split_autocast produces a fresh graph module whose output arg names may
    # differ; _sequential_split_and_maybe_inline_subgraphs_helper fixes up the
    # graph signature accordingly.
    new_gm = _split_autocast(gm)

    def _replace_or_inline(node: torch.fx.Node):
        if _is_autocast_sub_mod(node):
            _replace_with_hop(node)
            return
        assert node.op == "call_module"
        assert isinstance(node.target, str)
        node_inline_(node)

    return _sequential_split_and_maybe_inline_subgraphs_helper(
        new_gm, graph_signature, _replace_or_inline
    )
168
169
def replace_autocast_with_hop_pass(gm: torch.fx.GraphModule, graph_signature):
    """
    Entry point of the pass: rewrite autocast regions in ``gm`` as
    ``wrap_with_autocast`` higher-order ops.

    Delegates to ``_replace_with_hop_pass_helper``, which applies
    ``_sequential_split_and_maybe_inline_subgraphs`` to ``gm`` and then
    recursively to each of the resulting submodules.
    """
    return _replace_with_hop_pass_helper(
        gm, graph_signature, _sequential_split_and_maybe_inline_subgraphs
    )
180