# mypy: allow-untyped-defs

import torch
from torch._higher_order_ops.wrap import wrap_with_set_grad_enabled

from ..utils import node_inline_, nodes_filter, nodes_first, nodes_map, sequential_split
from .replace_with_hop_pass_util import (
    _replace_with_hop_helper,
    _replace_with_hop_pass_helper,
    _sequential_split_and_maybe_inline_subgraphs_helper,
)


def _is_set_grad_enabled_node(node: torch.fx.Node):
    return (
        node
        and node.op == "call_function"
        and node.target == torch._C._set_grad_enabled
    )


def _is_set_grad_enabled_sub_mod(node: torch.fx.Node, omit_if_same_with_ambient=False):
    if node.op == "call_module":
        assert isinstance(node.target, str)
        subgm = getattr(node.graph.owning_module, node.target)
        first_non_ph = nodes_first(
            subgm.graph.nodes, lambda node: node.op != "placeholder"
        )
        if (
            first_non_ph
            and first_non_ph.op == "call_function"
            and first_non_ph.target == torch._C._set_grad_enabled
        ):
            return (
                first_non_ph.args[0] != torch.is_grad_enabled()
                if omit_if_same_with_ambient
                else True
            )
    return False


def _replace_with_hop(node: torch.fx.Node):
    assert node.op == "call_module"
    graph: torch.fx.Graph = node.graph
    gm: torch.fx.GraphModule = graph.owning_module
    assert isinstance(node.target, str)
    sub_gm = getattr(gm, node.target)
    sub_graph = sub_gm.graph
    set_grad_nodes = nodes_filter(sub_graph.nodes, _is_set_grad_enabled_node)
    if len(set_grad_nodes) > 0:
        assert len(set_grad_nodes) == 1
        set_grad_node = set_grad_nodes[0]
        _replace_with_hop_helper(
            node, set_grad_node, _is_set_grad_enabled_node, wrap_with_set_grad_enabled
        )
        sub_graph.erase_node(set_grad_node)


def _remove_set_grad_and_inline(node: torch.fx.Node):
    assert node.op == "call_module"
    graph: torch.fx.Graph = node.graph
    gm: torch.fx.GraphModule = graph.owning_module
    assert isinstance(node.target, str)
    sub_gm = getattr(gm, node.target)
    sub_graph = sub_gm.graph
    nodes_map(
        sub_graph.nodes,
        lambda n: sub_graph.erase_node(n) if _is_set_grad_enabled_node(n) else n,
    )
    node_inline_(node)


def _sequential_split_and_maybe_inline_subgraphs(
    gm: torch.fx.GraphModule, graph_signature
):
    """
    Helper function for replace_set_grad_with_hop_pass().
    Split the graph module into multiple subgraphs based on the set_grad_enabled nodes.
    For each subgraph, decide whether to construct a HOP subgraph or to inline the calls
    back into the parent graph module.
    """
    need_replacing = any(_is_set_grad_enabled_node(node) for node in gm.graph.nodes)
    if not need_replacing:
        return gm, graph_signature

    # sequential_split returns a new graph module that could have different output
    # arg names. We need to fix the graph signature.
    new_gm = sequential_split(gm, _is_set_grad_enabled_node)

    def _maybe_inline_or_replace_with_hop(node: torch.fx.Node):
        if _is_set_grad_enabled_sub_mod(node, omit_if_same_with_ambient=True):
            _replace_with_hop(node)
        else:
            _remove_set_grad_and_inline(node)

    return _sequential_split_and_maybe_inline_subgraphs_helper(
        new_gm, graph_signature, _maybe_inline_or_replace_with_hop
    )


def replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule, graph_signature):
    """
    Split gm into sub-graph-modules using `_sequential_split_and_maybe_inline_subgraphs`,
    and then recursively call itself on each of the submodules.
    """
    return _replace_with_hop_pass_helper(
        gm,
        graph_signature,
        _sequential_split_and_maybe_inline_subgraphs,
    )