# mypy: allow-untyped-defs
from typing import List

import torch
from torch._higher_order_ops.wrap import wrap_with_autocast

from ..utils import node_inline_, nodes_filter, nodes_first, sequential_split
from .replace_with_hop_pass_util import (
    _replace_with_hop_helper,
    _replace_with_hop_pass_helper,
    _sequential_split_and_maybe_inline_subgraphs_helper,
)


def _is_autocast_node(node: torch.fx.Node):
    """
    Return a truthy value iff `node` is a `call_function` node targeting
    `_enter_autocast` or `_exit_autocast`. A falsy `node` (e.g. None) is
    returned as-is.
    """
    return (
        node
        and node.op == "call_function"
        and node.target
        in [
            torch.amp.autocast_mode._enter_autocast,
            torch.amp.autocast_mode._exit_autocast,
        ]
    )


def _is_enter_autocast_node(node: torch.fx.Node):
    """Return a truthy value iff `node` calls `torch.amp.autocast_mode._enter_autocast`."""
    return (
        node
        and node.op == "call_function"
        and node.target == torch.amp.autocast_mode._enter_autocast
    )


def _is_exit_autocast_node(node: torch.fx.Node):
    """Return a truthy value iff `node` calls `torch.amp.autocast_mode._exit_autocast`."""
    return (
        node
        and node.op == "call_function"
        and node.target == torch.amp.autocast_mode._exit_autocast
    )


def _is_autocast_sub_mod(node: torch.fx.Node):
    """
    Check whether `node` is a `call_module` node whose submodule's first
    non-placeholder node is `torch.amp.autocast_mode._enter_autocast`.
    """
    if node.op == "call_module":
        assert isinstance(node.target, str)
        subgm = getattr(node.graph.owning_module, node.target)
        first_non_ph = nodes_first(subgm.graph.nodes, lambda n: n.op != "placeholder")
        if (
            first_non_ph
            and first_non_ph.op == "call_function"
            and first_non_ph.target == torch.amp.autocast_mode._enter_autocast
        ):
            # TODO: check if current auto-cast type is the same as the args of
            # _enter_autocast. If so, return False, i.e. do not create a submodule.
            return True
    return False


def _check_valid_autocast_block(enter_autocast_node, exit_autocast_node):
    """
    Assert that the two nodes form a matching autocast pair: the first is an
    `_enter_autocast` call, the second an `_exit_autocast` call, and the exit
    node's first argument is that enter node.
    """
    assert _is_enter_autocast_node(enter_autocast_node)
    assert _is_exit_autocast_node(exit_autocast_node)
    assert exit_autocast_node.args[0] == enter_autocast_node


def _replace_with_hop(node: torch.fx.Node):
    """
    Replace the `call_module` `node` with the `wrap_with_autocast` higher-order op.

    The outermost `_enter_autocast` / `_exit_autocast` pair is taken to be the
    first and last autocast nodes of the submodule's graph. The replacement
    itself is delegated to `_replace_with_hop_helper`. Afterwards, those two
    nodes are erased from the submodule's graph. If the submodule contains no
    autocast nodes, nothing is changed.
    """
    assert node.op == "call_module"
    graph: torch.fx.Graph = node.graph
    gm: torch.fx.GraphModule = graph.owning_module
    assert isinstance(node.target, str)
    sub_gm = getattr(gm, node.target)
    sub_graph = sub_gm.graph
    autocast_nodes = nodes_filter(sub_graph.nodes, _is_autocast_node)
    if len(autocast_nodes) > 0:
        assert len(autocast_nodes) > 1  # need at least an enter node and an exit node
        enter_autocast_node = autocast_nodes[0]
        exit_autocast_node = autocast_nodes[-1]
        _check_valid_autocast_block(enter_autocast_node, exit_autocast_node)

        _replace_with_hop_helper(
            node, enter_autocast_node, _is_autocast_node, wrap_with_autocast
        )
        sub_graph.erase_node(exit_autocast_node)
        sub_graph.erase_node(enter_autocast_node)


def _split_autocast(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """
    split_autocast creates a new graph module that splits the input graph module into multiple submodules
    based on the `_enter_autocast` and `_exit_autocast` nodes. It doesn't mutate the input graph module.

    Nodes between the **outer-most** `_enter_autocast` and `_exit_autocast(_enter_autocast)` are split
    into a submodule. Nested autocast regions are not split: an inner enter/exit pair stays in the
    submodule of its enclosing outer-most region.
    `_enter_autocast` and `_exit_autocast(_enter_autocast)` nodes are in the submodule as well.

    Below is an example of splitting. A, B, C, D, E are blocks of non-autocast nodes in the original graph
    module. Nodes marked with the same number are grouped into the same submodule.
    A               # 0
    enter_autocast  # 1
    B               # 1
    exit_autocast   # 1
    C               # 2
    enter_autocast  # 3
    D               # 3
    exit_autocast   # 3
    E               # 4
    """
    enter_autocast_node_stack: List[torch.fx.Node] = []
    first_node_after_outer_most_exit: bool = False

    def node_call_back(node: torch.fx.Node) -> bool:
        # Returns True when `node` must begin a new submodule (see the grouping
        # in the docstring): at an outer-most `_enter_autocast`, and at the
        # first node after an outer-most `_exit_autocast`.
        nonlocal first_node_after_outer_most_exit
        increment_id = False
        if first_node_after_outer_most_exit or (
            len(enter_autocast_node_stack) == 0 and _is_enter_autocast_node(node)
        ):
            assert len(enter_autocast_node_stack) == 0
            first_node_after_outer_most_exit = False
            increment_id = True

        # Track every enter/exit pair, including nested ones, so that each
        # `_exit_autocast` is matched against its own `_enter_autocast`.
        # (Previously nested enters were never pushed, so a nested exit
        # popped the outer enter and tripped the assertion below.)
        if _is_enter_autocast_node(node):
            enter_autocast_node_stack.append(node)
        elif _is_exit_autocast_node(node):
            assert len(enter_autocast_node_stack) > 0
            last_enter_autocast_node = enter_autocast_node_stack.pop()
            assert node.args[0] == last_enter_autocast_node
            if len(enter_autocast_node_stack) == 0:
                # next node should be in the next submodule since
                # the outer-most autocast block ends
                first_node_after_outer_most_exit = True
        return increment_id

    return sequential_split(gm, node_call_back)


def _sequential_split_and_maybe_inline_subgraphs(
    gm: torch.fx.GraphModule, graph_signature
):
    """
    Helper function for replace_autocast_with_hop_pass().
    Split the graph module into multiple subgraphs based on the autocast nodes.
    For each subgraph, decide whether to construct a higher-order-op (HOP)
    subgraph, or inline the calls back into the parent graph module.
    Nodes between `_enter_autocast` and `_exit_autocast(_enter_autocast)` are considered
    as a subgraph.

    If `gm` contains no autocast nodes, `(gm, graph_signature)` is returned unchanged.
    """
    need_replacing = any(_is_autocast_node(node) for node in gm.graph.nodes)
    if not need_replacing:
        return gm, graph_signature

    # split_autocast returns a new graph module that could have different output
    # args names. We need to fix the graph signature in `_sequential_split_and_maybe_inline_subgraphs_helper`.
    new_gm = _split_autocast(gm)

    def _maybe_inline_or_replace_with_hop(node: torch.fx.Node):
        # Submodules that start with `_enter_autocast` become a HOP call;
        # every other split submodule is inlined back into the parent.
        if _is_autocast_sub_mod(node):
            _replace_with_hop(node)
        else:
            assert node.op == "call_module"
            assert isinstance(node.target, str)
            node_inline_(node)

    return _sequential_split_and_maybe_inline_subgraphs_helper(
        new_gm, graph_signature, _maybe_inline_or_replace_with_hop
    )


def replace_autocast_with_hop_pass(gm: torch.fx.GraphModule, graph_signature):
    """
    Split gm into sub-graph-modules using `_sequential_split_and_maybe_inline_subgraphs`, and
    then recursively call itself on each of the submodules.

    The work is delegated to `_replace_with_hop_pass_helper`, whose return value
    is returned as-is.
    """
    return _replace_with_hop_pass_helper(
        gm,
        graph_signature,
        _sequential_split_and_maybe_inline_subgraphs,
    )