# mypy: allow-untyped-defs
import re
from typing import Callable, Dict, Optional, Set, Union

import torch.fx
from torch.fx.node import map_arg
from torch.fx.passes.split_module import split_module


__all__ = ['FoldedGraphModule', 'get_unique_attr_name_in_module', 'split_const_subgraphs']

class FoldedGraphModule(torch.fx.GraphModule):
    """
    FoldedGraphModule is a GraphModule which also contains a
    `const_subgraph_module`, a subgraph whose inputs are all constant attrs and
    which can therefore be run once before running the main standard `graph`.
    `fx_const_folded_attrs_name` is the name of the attr on this module under
    which the outputs of the const subgraph are stored once folding has run.
    """

    def __init__(
        self,
        root: torch.nn.Module,
        graph: torch.fx.Graph,
        const_subgraph: Optional[torch.fx.Graph] = None,
        fx_const_folded_attrs_name: Optional[str] = None,
        device_for_folded_attrs: str = "cuda",
    ):
        super().__init__(root, graph)
        self.const_subgraph_module = (
            None
            if const_subgraph is None
            else torch.fx.GraphModule(root, const_subgraph)
        )
        self.has_folding_been_run = False
        self.fx_const_folded_attrs_name = fx_const_folded_attrs_name
        self.device_for_folded_attrs = device_for_folded_attrs

    def __call__(self, *args, **kwargs):
        if not self.has_folding_been_run:
            self.run_folding()
        return super().__call__(*args, **kwargs)

    def run_folding(self):
        # If there's no const subgraph module or attr output names to use, return
        # early as there is no const folding to perform.
        if (
            self.const_subgraph_module is None
            or self.fx_const_folded_attrs_name is None
        ):
            return

        assert not self.has_folding_been_run
        self.has_folding_been_run = True

        # Actually run const folding subgraph. Note that single attr const fold
        # subgraphs output a single Tensor while multiple outputs are returned as
        # Tuple[Tensor,].
        folded_attrs = self.const_subgraph_module()

        def _create_param(i):
            return torch.nn.Parameter(
                i.detach().clone()
                if not isinstance(i, int)
                else torch.Tensor([i]).to(device=self.device_for_folded_attrs),
                requires_grad=i.requires_grad if isinstance(i, torch.Tensor) else False,
            )

        params = (
            torch.nn.ParameterList([_create_param(i) for i in folded_attrs])
            if isinstance(folded_attrs, tuple)
            else _create_param(folded_attrs)
        )
        setattr(self, self.fx_const_folded_attrs_name, params)
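

# Illustrative note (a sketch, not an API guarantee): after folding has run --
# lazily on the first call, or explicitly via `run_folding()` -- the folded
# results live on the module under the attr named by `fx_const_folded_attrs_name`.
# Assuming `folded_gm` is a FoldedGraphModule produced by `split_const_subgraphs`
# below:
#
#     folded_gm.run_folding()
#     folded = getattr(folded_gm, folded_gm.fx_const_folded_attrs_name)
#     # `folded` is a single torch.nn.Parameter when the const subgraph has one
#     # output, or a torch.nn.ParameterList when it has multiple outputs.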


def _inline_module(gm: torch.fx.GraphModule, inline_mod_name: str):
    """
    Given `gm` and some graph module which is called with target name `inline_mod_name`,
    this helper will inline all of the nodes from that called graph module into `gm`.
    """
    # Fetch the inner graph module that we want to inline inside `gm`.
    inline_mod = dict(gm.named_modules())[inline_mod_name]
    assert isinstance(inline_mod, torch.fx.GraphModule)
    call_mod_node_to_replace = None
    for node in gm.graph.nodes:
        if node.op == "call_module" and node.target == inline_mod_name:
            call_mod_node_to_replace = node
            break
    assert call_mod_node_to_replace is not None

    # Now actually do the swap. Note that we have to keep track of new nodes that are
    # copied into `gm` -- we do this via replacement_mapping.
    call_mod_args = call_mod_node_to_replace.args
    replacement_mapping: Dict[torch.fx.Node, torch.fx.Node] = {}
    ph_count = 0

    def replacement_fn(node):
        new_node = replacement_mapping[node]
        new_node.meta = node.meta.copy()
        return new_node

    for inline_node in inline_mod.graph.nodes:
        if inline_node.op == "placeholder":
            replacement_mapping[inline_node] = call_mod_args[ph_count]
            ph_count += 1
            continue

        if inline_node.op == "output":
            outputs = inline_node.args[0]
            output_replacements = map_arg(outputs, replacement_fn)
            call_mod_node_to_replace.replace_all_uses_with(output_replacements)
            continue

        with gm.graph.inserting_before(call_mod_node_to_replace):
            new_node = gm.graph.node_copy(inline_node, replacement_fn)
        replacement_mapping[inline_node] = new_node

    gm.graph.eliminate_dead_code()


def get_unique_attr_name_in_module(mod_traced: torch.fx.GraphModule, name: str) -> str:
    """
    Make sure the name is unique (in a module) and can represent an attr.
    """
    # Delete all characters that are illegal in a Python identifier.
    name = re.sub("[^0-9a-zA-Z_]+", "_", name)
    if name[0].isdigit():
        name = f"_{name}"
    # Now make sure it is in fact unique to the module by incrementing suffix value.
    while hasattr(mod_traced, name):
        match = re.match(r"(.*)_(\d+)$", name)
        if match is None:
            name = name + "_1"
        else:
            base, num = match.group(1, 2)
            name = f"{base}_{int(num) + 1}"

    return name
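

# Illustrative example (hypothetical names): the helper first sanitizes the name
# into a valid Python identifier, then bumps a numeric suffix until the name is
# not already an attribute of the module.
#
#     get_unique_attr_name_in_module(gm, "layer.0/weight")  # -> "layer_0_weight"
#     # If "layer_0_weight" is already taken on `gm`, the result becomes
#     # "layer_0_weight_1", then "layer_0_weight_2", and so on.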


def split_const_subgraphs(
    module: Union[torch.nn.Module, torch.fx.GraphModule],
    skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
    device_for_folded_attrs: str = "cpu",
) -> FoldedGraphModule:
    """
    Looks through `module` for any nodes that have all constant attribute inputs
    and separates them out into their own constant subgraph. Returns a
    FoldedGraphModule which runs that constant subgraph on the first run to set
    attributes on the module prior to running the non-constant portion of the
    graph.
    """
    if not isinstance(module, torch.fx.GraphModule):
        mod_traced = torch.fx.symbolic_trace(module)
    else:
        mod_traced = module

    # Build up a list of const_nodes, defined as nodes that are themselves
    # get_attrs, or have all get_attr or other constant node inputs.
    const_nodes: Set[torch.fx.Node] = set()
    found_const_folding = False
    for node in mod_traced.graph.nodes:
        # Skip over placeholders/outputs because they can't be const folded and
        # we don't want to add tags to them.
        if node.op in {"placeholder", "output"}:
            continue

        # If the node itself is constant, or all of its inputs are constant,
        # then tag it as constant.
        if node.op != "get_attr" and not set(node.all_input_nodes).issubset(
            const_nodes
        ):
            continue

        # If the provided skip folding function says to skip, then skip.
        if skip_folding_node_fn and skip_folding_node_fn(node):
            continue

        # Skip folding side-effectful functions.
        if node.is_impure():
            continue

        # Must be a constant foldable node at this point.
        const_nodes.add(node)
        if node.op != "get_attr":
            found_const_folding = True

    # If we did not find any const folding then return early without a const fold subgraph.
    if not found_const_folding:
        return FoldedGraphModule(mod_traced, mod_traced.graph)
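
    # As an illustrative example (hypothetical graph), at this point a module
    # whose graph looks like:
    #     %x : placeholder
    #     %w : get_attr[target=w]
    #     %b : get_attr[target=b]
    #     %add : call_function[target=operator.add](args = (%w, %b))
    #     %mul : call_function[target=operator.mul](args = (%x, %add))
    # would have const_nodes == {%w, %b, %add}; %mul is excluded because it
    # depends on the non-constant placeholder %x.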

    # Partition the module into two: submod_0 for constant folding subgraph, and
    # submod_1 for the rest.
    def mod_partition(node: torch.fx.Node):
        return 0 if node in const_nodes else 1

    split = split_module(mod_traced, module, mod_partition)

    const_mod_name, non_const_mod_name = "submod_0", "submod_1"
    # Safely get submod_1 in case there are no non-const nodes.
    const_gm, non_const_gm = split.submod_0, getattr(split, non_const_mod_name, None)

    # The module that a call_module node refers to gets copied to submodules during split.
    # The path to the module also gets inlined, i.e. mod.a.b -> mod_a_b. Here we need to
    # attach inlined modules to `split` as it's the owning module now.
    for node in non_const_gm.graph.nodes if non_const_gm else []:
        if node.op == "call_module":
            setattr(split, node.target, getattr(non_const_gm, node.target))
    for node in const_gm.graph.nodes:
        if node.op == "call_module":
            setattr(split, node.target, getattr(const_gm, node.target))

    # split_module currently does not use get_attrs for attrs. Instead it passes
    # them in as args from the parent module, which used get_attrs. Here we set
    # them as get_attrs inside const_gm, allowing for running folding without
    # somehow a priori knowing the attrs that should be passed as args. We can
    # unconditionally do this for all placeholders because we know all
    # placeholders to const_gm must be constants accessible via get_attr.
    call_const_gm_args = None
    for node in split.graph.nodes:
        if node.op == "call_module":
            if node.target == const_mod_name:
                call_const_gm_args = node.args
                break
    assert call_const_gm_args is not None

    # Here we do the actual replacement of placeholders to get_attrs. Note that here we
    # set the const_gm.graph into a new root_const_gm with split as the root module,
    # because we are fetching attributes directly from the root module, instead of
    # fetching them from const_gm. Example: The const_gm must have some format like:
    # graph():
    #     %inp : [num_users=1] = placeholder[target=const_inp]
    #     %add : [num_users=1] = call_function[target=operator.add](args = (%inp, %inp), kwargs = {})
    #     return add
    # We replace that with the following, which does not have any placeholders:
    # graph():
    #     %inp_1 : [num_users=1] = get_attr[target=const_inp]
    #     %add : [num_users=1] = call_function[target=operator.add](args = (%inp_1, %inp_1), kwargs = {})
    #     return add
    root_const_gm = torch.fx.GraphModule(split, const_gm.graph)
    for node in root_const_gm.graph.nodes:
        if node.op == "output":
            multiple_outputs = isinstance(node.args[0], tuple)
            continue
        if node.op != "placeholder":
            continue
        in_node = next(n for n in call_const_gm_args if n.name == node.target)
        assert in_node.op == "get_attr"
        with root_const_gm.graph.inserting_before(node):
            new_node = root_const_gm.graph.get_attr(in_node.target)
            new_node.meta = node.meta.copy()
            node.replace_all_uses_with(new_node)
            root_const_gm.graph.erase_node(node)
    assert "multiple_outputs" in locals()

    # Now find the call to const_gm inside split, and replace it with a getattr to the
    # folded tensor(s) that result from constant folding. Note that we don't need to
    # worry about whether this is one or more tensors because the original graph
    # correctly uses getitem to extract individual tensors if there are multiple folded.
    fx_const_folded_attrs_name = get_unique_attr_name_in_module(
        mod_traced, "_FX_CONST_FOLDED_ATTRS"
    )
    setattr(
        split,
        fx_const_folded_attrs_name,
        torch.nn.ParameterList() if multiple_outputs else torch.nn.Parameter(),  # type: ignore[possibly-undefined]
    )
    for node in split.graph.nodes:
        if node.op == "call_module" and node.target == const_mod_name:
            with node.graph.inserting_before(node):
                folded_attrs = node.graph.get_attr(fx_const_folded_attrs_name)
                folded_attrs.meta = node.meta.copy()
            node.replace_all_uses_with(folded_attrs)
            break

    # Finally, inline the non-constant submod (if it exists) into the split submod.
    # This is so that the original caller who may have passed in a graph module will
    # get back out a graph module whose graph is traced to the same granularity.
    if hasattr(split, non_const_mod_name):
        _inline_module(split, non_const_mod_name)

    split.graph.eliminate_dead_code()

    return FoldedGraphModule(
        split,
        split.graph,
        root_const_gm.graph,
        fx_const_folded_attrs_name,
        device_for_folded_attrs,
    )
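

# Example usage (an illustrative sketch; `ExampleModule` and its attributes are
# hypothetical and not part of this file):
#
#     class ExampleModule(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.w = torch.nn.Parameter(torch.randn(4, 4))
#             self.b = torch.nn.Parameter(torch.randn(4))
#
#         def forward(self, x):
#             # `self.w + self.b` has only constant (get_attr) inputs, so it is
#             # expected to be split into the const subgraph and folded.
#             return x @ (self.w + self.b)
#
#     folded = split_const_subgraphs(ExampleModule())
#     out = folded(torch.randn(2, 4))  # constant folding runs once on the first call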