# mypy: allow-untyped-defs

import contextlib
import copy
import operator
from typing import Callable

import torch
from torch._ops import HigherOrderOperator

from ..utils import node_replace_, nodes_map


def _replace_with_hop_helper(
    node: torch.fx.Node,
    enter_block_node: torch.fx.Node,
    node_filter: Callable,
    wrap_hoo: HigherOrderOperator,
):
    """Replace a ``call_module`` node with a call to the higher order operator
    ``wrap_hoo`` wrapping the submodule the node targets.

    Args:
        node: the ``call_module`` node to replace; its string target names a
            submodule attribute on the owning ``GraphModule``.
        enter_block_node: node whose metadata (``nn_module_stack``) and args
            are propagated onto the new HOO call.
        node_filter: unused here; kept for interface compatibility with callers.
        wrap_hoo: the higher order operator to call instead of the submodule.

    If the submodule's graph has no output node (split_module omits it when
    nothing is returned), the ``call_module`` node is simply erased.
    """
    graph: torch.fx.Graph = node.graph
    gm: torch.fx.GraphModule = graph.owning_module
    assert isinstance(node.target, str)
    sub_gm = getattr(gm, node.target)

    def set_hoo_node_meta(call_func_node):
        # Propagate nn_module_stack and record the HOO identity in torch_fn.
        # `output_args` is read from the enclosing scope; it is always bound
        # before this helper is invoked.
        call_func_node.meta["nn_module_stack"] = copy.copy(
            enter_block_node.meta.get("nn_module_stack", {})
        )
        call_func_node.meta["torch_fn"] = (
            wrap_hoo.__name__,
            f"{wrap_hoo.__class__.__name__}.{wrap_hoo.__name__}",
        )
        # HOO calls always produce a tuple of values, so "val" is a tuple even
        # for a single-output subgraph.
        if isinstance(output_args, (tuple, list)):
            call_func_node.meta["val"] = tuple(arg.meta["val"] for arg in output_args)
        elif isinstance(output_args, torch.fx.Node):
            call_func_node.meta["val"] = (output_args.meta["val"],)

    with graph.inserting_before(node):
        get_attr_node = graph.get_attr(node.target)
        get_attr_node.meta["nn_module_stack"] = copy.copy(
            enter_block_node.meta.get("nn_module_stack", {})
        )
        output_node = next(iter(reversed(sub_gm.graph.nodes)), None)
        # Split_module pass intentionally doesn't add an output node
        # if the graph doesn't return anything.
        # TODO (tmanlaibaatar) Figure out if this is right behaviour
        # for split_module
        if isinstance(output_node, torch.fx.Node) and output_node.op != "output":
            output_node = None
        if output_node is not None:
            assert len(output_node.args) == 1
            output_args = output_node.args[0]
            enter_block_node_args = enter_block_node.args
            if isinstance(output_args, (tuple, list)):
                call_func_node = graph.call_function(
                    wrap_hoo,
                    (*enter_block_node_args, get_attr_node, *node.args),
                    {},
                )
                # Create the metadata
                set_hoo_node_meta(call_func_node)
                node_replace_(node, call_func_node)

                # Rename the getitem nodes to the actual names of their
                # contents for passing the verifier and better readability;
                # also propagate metadata.
                for get_item_node in call_func_node.users.keys():
                    idx: int = get_item_node.args[1]  # type: ignore[assignment]
                    out_arg = output_args[idx]
                    get_item_node._rename(out_arg.name)
                    get_item_node.meta = out_arg.meta

            elif isinstance(output_args, torch.fx.Node):
                call_func_node = graph.create_node(
                    "call_function",
                    wrap_hoo,
                    (*enter_block_node_args, get_attr_node, *node.args),
                    {},
                    output_args.name,
                )
                # Modify the subgraph to output a singleton list.
                output_node.args = ((output_args,),)
                # Add in an extra `getitem(wrap_hoo, 0)` node to the toplevel graph.
                get_item_node = graph.create_node(
                    "call_function",
                    operator.getitem,
                    (call_func_node, 0),
                    {},
                )
                # Create the metadata
                get_item_node.meta = output_args.meta
                set_hoo_node_meta(call_func_node)
                node_replace_(node, get_item_node)
            else:
                raise NotImplementedError(
                    f"replace_with_hop_pass doesn't support output type {type(output_args)}"
                )
        else:
            # TODO (shangdiy): remove this line, since the export graph can be non-functional
            node.graph.erase_node(node)


def _sequential_split_and_maybe_inline_subgraphs_helper(
    new_gm: torch.fx.GraphModule,
    graph_signature,
    maybe_inline_or_replace_with_hop: Callable[[torch.fx.Node], None],
):
    """
    Helper function for replacing graph nodes with higher order nodes.
    For each subgraph in `new_gm`, decides whether to construct a HOO subgraph, or inline the calls
    back into the parent graph module, depending on `maybe_inline_or_replace_with_hop`.
    """
    # new_gm is a new graph module that could have different output args names.
    # We need to fix the graph signature.
    replace_ctx = contextlib.nullcontext()
    new_signature = None
    if graph_signature is not None:
        # Cannot deep copy a real ScriptObject, which is referenced
        # in the FakeScriptObject. Copy should be good enough to guard
        # against accidental mutation to original graph_signature.
        new_signature = copy.copy(graph_signature)
        new_gm_out_node = next(reversed(new_gm.graph.find_nodes(op="output")))
        assert new_gm_out_node.op == "output" and len(new_gm_out_node.args[0]) == len(
            new_signature.output_specs
        )
        # Sync the output spec names with the (possibly renamed) output nodes.
        for arg_node, out_spec in zip(
            new_gm_out_node.args[0], new_signature.output_specs
        ):
            if arg_node is None:
                assert out_spec.arg.value is None
            elif (
                isinstance(arg_node, torch.fx.Node)
                and out_spec.arg.name != arg_node.name
            ):
                out_spec.arg.name = arg_node.name

        replace_ctx = new_gm._set_replace_hook(new_signature.get_replace_hook())  # type: ignore[assignment]

    with replace_ctx:
        nodes_map(
            list(new_gm.graph.nodes),
            lambda node: (
                maybe_inline_or_replace_with_hop(node)
                if node.op == "call_module"
                else node
            ),
        )
    new_gm.recompile()
    return new_gm, new_signature


def _replace_with_hop_pass_helper(
    gm: torch.fx.GraphModule,
    graph_signature,
    sequential_split_and_maybe_inline_subgraphs: Callable,
):
    """
    Split gm into sub-graph-modules using `sequential_split_and_maybe_inline_subgraphs`, and
    then recursively call itself on each of the submodules.
    """
    new_gm, new_signature = sequential_split_and_maybe_inline_subgraphs(
        gm, graph_signature
    )
    # Recursively process nested GraphModule attributes; only the top level
    # carries a graph signature, so submodules are processed with None.
    for node in new_gm.graph.nodes:
        if node.op == "get_attr":
            subgm = getattr(new_gm, node.target)
            if not isinstance(subgm, torch.fx.GraphModule):
                continue
            new_subgm, _ = _replace_with_hop_pass_helper(
                subgm,
                None,
                sequential_split_and_maybe_inline_subgraphs,
            )
            setattr(new_gm, node.target, new_subgm)

    new_gm.recompile()
    new_gm.graph.lint()
    return new_gm, new_signature