# mypy: allow-untyped-defs
import inspect
from typing import Any, Callable, Dict, List, Optional, Set
from collections import OrderedDict
import logging

import torch
from torch.fx._compatibility import compatibility
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node
from torch.fx._utils import lazy_format_graph_code


__all__ = ["Partition", "split_module"]
log = _LOGGER = logging.getLogger(__name__)

@compatibility(is_backward_compatible=True)
class Partition:
    def __init__(self, name: str):
        self.name: str = name
        self.submod_name = f"submod_{name}"
        self.node_names: List[str] = []
        self.inputs: Dict[str, None] = {}
        self.outputs: Dict[str, None] = {}
        self.dependencies: Dict[str, None] = {}
        self.dependents: Dict[str, None] = {}
        self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
        self.environment: Dict[Node, Node] = {}
        self.targets: Dict[str, Any] = {}

    def __repr__(self) -> str:
        return (
            f"name: {self.name},\n"
            f" nodes: {self.node_names},\n"
            f" inputs: {self.inputs},\n"
            f" outputs: {self.outputs},\n"
            f" partitions depended on: {self.dependencies},\n"
            f" partition dependents: {self.dependents}"
        )


# Creates subgraphs out of the main graph
@compatibility(is_backward_compatible=True)
def split_module(
    m: GraphModule,
    root_m: torch.nn.Module,
    split_callback: Callable[[Node], int],
    qualname_map: Optional[Dict[str, str]] = None,
    keep_original_order: Optional[bool] = False,
    keep_original_node_name: Optional[bool] = False,
):
    """
    Creates subgraphs out of the main graph.

    Args:
        m (GraphModule): Graph module to split
        root_m (torch.nn.Module): root nn module. Not currently used. Included
            because the root nn module is usually transformed via
            torch.fx._symbolic_trace.symbolic_trace (see example below)
        split_callback (Callable[[Node], int]): Callable function
            that maps a given Node instance to a numeric partition identifier.
            split_module will use this function as the policy for which operations
            appear in which partitions in the output Module.
        qualname_map: Optional[Dict[str, str]]: optional output parameter that returns a
            mapping from new target names in the module after split to old target
            names in the original module.
        keep_original_order: Optional[bool]: keep the original order of the GraphModule
            or use the topological order of the newly constructed GraphModule
        keep_original_node_name: Optional[bool]: if True, keep the original node names
            in the split module.


    Returns:
        GraphModule: the module after split.

    Example:

        This is a sample setup:

            import torch
            from torch.fx.symbolic_trace import symbolic_trace
            from torch.fx.graph_module import GraphModule
            from torch.fx.node import Node
            from torch.fx.passes.split_module import split_module

            class MyModule(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.param = torch.nn.Parameter(torch.rand(3, 4))
                    self.linear = torch.nn.Linear(4, 5)

                def forward(self, x, y):
                    z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
                    w = self.linear(y).clamp(min=0.0, max=1.0)
                    return z + w

            # symbolically trace model
            my_module = MyModule()
            my_module_traced = symbolic_trace(my_module)

            # random mod partitioning
            partition_counter = 0
            NPARTITIONS = 3

            def mod_partition(node: Node):
                global partition_counter
                partition = partition_counter % NPARTITIONS
                partition_counter = (partition_counter + 1) % NPARTITIONS
                return partition

            # split the module into a module with submodules
            module_with_submodules = split_module(
                my_module_traced, my_module, mod_partition
            )

        The output looks like this. The original graph is broken into partitions:

            > print(module_with_submodules)
            GraphModule(
                (submod_0): GraphModule(
                    (linear): Linear(in_features=4, out_features=5, bias=True)
                )
                (submod_1): GraphModule(
                    (linear): Linear(in_features=4, out_features=5, bias=True)
                )
                (submod_2): GraphModule()
            )

            def forward(self, x, y):
                param = self.param
                submod_0 = self.submod_0(x, param, y); x = param = y = None
                getitem = submod_0[0]
                getitem_1 = submod_0[1]; submod_0 = None
                submod_1 = self.submod_1(getitem, getitem_1); getitem = getitem_1 = None
                getitem_2 = submod_1[0]
                getitem_3 = submod_1[1]; submod_1 = None
                submod_2 = self.submod_2(getitem_2, getitem_3); getitem_2 = getitem_3 = None
                return submod_2

        The output of the split module is the same as the output of the input
        traced module. This is an example within a test setting:

            > orig_out = my_module_traced(x, y)
            > submodules_out = module_with_submodules(x, y)
            > self.assertEqual(orig_out, submodules_out)
            True
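
        If a dict is passed in as qualname_map, it is filled in with a mapping
        from the new target names in the split module to the corresponding
        target names in the original module. For a setup like the one above it
        would look roughly as follows (the exact contents depend on how the
        nodes get partitioned):

            > qualname_map = {}
            > module_with_submodules = split_module(
            >     my_module_traced, my_module, mod_partition, qualname_map
            > )
            > print(qualname_map)
            {'submod_0.linear': 'linear', 'submod_1.linear': 'linear'}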
    """

    log.debug(
        "%s",
        lazy_format_graph_code(
            "pre split_module", m, colored=True
        ),
    )

    def construct_graph(
        node: Node,
        base_mod_env: Dict[str, Node],
        base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule],
    ):
        if node.op == "placeholder":
            default_value = (
                node.args[0] if len(node.args) > 0 else inspect.Signature.empty
            )
            if keep_original_node_name:
                args = () if default_value is inspect.Signature.empty else (default_value,)
                base_mod_env[node.name] = base_mod_graph.create_node('placeholder', node.name, args=args, type_expr=node.type)  # type: ignore[arg-type]
            else:
                base_mod_env[node.name] = base_mod_graph.placeholder(
                    node.target, type_expr=node.type, default_value=default_value  # type: ignore[arg-type]
                )
            base_mod_env[node.name].meta = node.meta.copy()
        elif node.op == "get_attr":
            base_mod_env[node.name] = base_mod_graph.get_attr(node.target)  # type: ignore[arg-type]
            base_mod_env[node.name].meta = node.meta.copy()
            attr_val = m
            for atom in node.target.split("."):  # type: ignore[union-attr]
                if not hasattr(attr_val, atom):
                    raise AttributeError(f"Node target {node.target} not found!")
                attr_val = getattr(attr_val, atom)
            base_mod_attrs[node.target] = attr_val  # type: ignore[index]
        return base_mod_env, base_mod_attrs

    import sympy

    partitions: Dict[str, Partition] = {}
    orig_nodes: Dict[str, Node] = {}
    symbol_to_node: Dict[sympy.Symbol, Node] = {}
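
    # Record a cross-partition use of a value: the value becomes an output of
    # the partition that defines it (if any), an input of the partition that
    # uses it, and a dependency edge between the two partitions is recorded.
    # Symbols appearing in the value's symbolic shape are routed into the using
    # partition as well.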
    def record_cross_partition_use(
        def_node: Node, use_node: Optional[Node]
    ):  # noqa: B950
        from torch.fx.experimental.symbolic_shapes import free_symbols

        defined = getattr(def_node, "_fx_partition", None)
        used = getattr(use_node, "_fx_partition", None)

        log.debug(
            "record_cross_partition_use %s (%s) %s (%s)",
            def_node.name, defined, use_node.name if use_node is not None else "-", used
        )

        if defined != used:
            if defined is not None:
                def_partition = partitions[defined]
                def_partition.outputs.setdefault(def_node.name)
                if used is not None:
                    def_partition.dependents.setdefault(used)

            if used is not None:
                use_partition = partitions[used]
                use_partition.inputs.setdefault(def_node.name)
                # We have made def_node an input to the use_partition. If
                # this input has symbolic symbols in its size, those also must
                # be made as inputs to the partition
                if (def_val := def_node.meta.get("example_value")) is not None:
                    for s in sorted(free_symbols(def_val), key=str):
                        s_node = symbol_to_node[s]
                        use_partition.inputs.setdefault(s_node.name)
                        if symbol_to_node[s].op != "placeholder":
                            # If the node that defines the symbol is not a
                            # placeholder, we must make it an output of the
                            # partition. Note that this may be in a different
                            # partition than defined! Although this doesn't
                            # really make a difference for correctness, since
                            # defined is guaranteed to have the symbol in
                            # scope and can return it; you just get less
                            # optimal codegen in this case.
                            s_defined = getattr(s_node, "_fx_partition", None)
                            if s_defined is not None:
                                s_def_partition = partitions[s_defined]
                                s_def_partition.outputs.setdefault(s_node.name)
                                s_def_partition.dependents.setdefault(used)
                if defined is not None:
                    use_partition.dependencies.setdefault(defined)

    def instantiate_node_partition_mapping(node):
        partition_name = str(split_callback(node))
        log.debug("instantiate_node_partition_mapping %s (%s)", node.name, partition_name)

        # add node to partitions
        partition = partitions.get(partition_name)
        if partition is None:
            partitions[partition_name] = partition = Partition(partition_name)

        partition.node_names.append(node.name)
        node._fx_partition = partition_name

    # Global State Nodes are nodes which, through their global state effects,
    # "taint" all downstream nodes while they are active.
    GLOBAL_STATE_NODES = [
        torch.amp._enter_autocast,
        torch.amp._exit_autocast,
        torch._C._set_grad_enabled
    ]

    # For grad regions:
    # ------------------------
    # 1. first region: we do nothing
    # 2. subsequent regions: we insert the set_grad at the beginning
    grad_regions: OrderedDict[Node, Set[int]] = OrderedDict()

    # For autocast regions:
    # ------------------------
    # 1. first region: we will only insert the _exit at the end
    # 2. intermediate regions: we will insert both the
    #    _enter at the beginning and _exit at the end
    # 3. last region: we will only insert _enter at the beginning
    # We will do so in the order in which the autocasts were instantiated.
    autocast_regions: OrderedDict[Node, Set[int]] = OrderedDict()
    autocast_exits: Dict[Node, Optional[Node]] = {}

    active_grad = None
    active_autocasts = set()
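
    # First pass over the graph: assign each compute node to a partition via
    # split_callback, record which partitions every grad/autocast region spans,
    # and remember which node binds each symbolic shape symbol.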
    for node in m.graph.nodes:
        # This will prefer placeholder bindings, because those come first.
        # This is a little dangerous though: it is possible that an unbacked
        # symbol is used without any binding site for it, in which case we
        # will get a KeyError because we are not able to find it. I'd like to
        # fix this by having passes.runtime_assert establish some invariants
        # that I can rely on later, but this needs some extra work. Quick fix
        # first.
        # See https://github.com/pytorch/pytorch/issues/130534
        if (
            (val := node.meta.get("example_value")) is not None and
            isinstance(val, torch.SymInt) and
            isinstance(s0 := val.node.expr, sympy.Symbol) and
            s0 not in symbol_to_node
        ):
            symbol_to_node[val.node.expr] = node

        if node.op in ["placeholder", "get_attr", "output"]:
            continue

        instantiate_node_partition_mapping(node)

        if node.op == "call_function" and node.target in GLOBAL_STATE_NODES:
            if node.target == torch._C._set_grad_enabled:
                assert len(node.args) == 1
                assert isinstance(node.args[0], bool)
                active_grad = node
                grad_regions[active_grad] = set({split_callback(node)})
            elif node.target == torch.amp._enter_autocast:
                # Should all be python constants
                assert all(not isinstance(arg, Node) for arg in node.args)
                active_autocasts.add(node)
                autocast_regions[node] = set({split_callback(node)})
                autocast_exits[node] = None
            elif node.target == torch.amp._exit_autocast:
                assert len(node.args) == 1
                autocast_regions[node.args[0]].add(split_callback(node))
                active_autocasts.remove(node.args[0])
                autocast_exits[node.args[0]] = node

        if active_grad is not None:
            grad_regions[active_grad].add(split_callback(node))

        for a in active_autocasts:
            autocast_regions[a].add(split_callback(node))

    assert all(v is not None for v in autocast_exits.values()), "autocast must exit"

    autocast_regions = {k: sorted(v) for k, v in autocast_regions.items()}
    grad_regions = {k: sorted(v) for k, v in grad_regions.items()}

    if _LOGGER.isEnabledFor(logging.DEBUG):
        _LOGGER.debug("autocast_regions: %s", autocast_regions)
        _LOGGER.debug("grad_regions: %s", grad_regions)

    assert_monotonically_increasing = bool(autocast_regions) or bool(grad_regions)

    # split nodes into partitions
    highest_partition = -1
    for node in m.graph.nodes:
        orig_nodes[node.name] = node

        # TODO currently placeholders/parameters aren't put into random partitions,
        # rather they're added to the graphs where they are used down below
        if node.op in ["placeholder", "get_attr"]:
            continue
        if node.op == "output":
            torch.fx.graph.map_arg(
                node.args[0], lambda n: record_cross_partition_use(n, None)
            )
            continue

        if assert_monotonically_increasing:
            pid = split_callback(node)
            assert highest_partition <= pid, \
                ("autocast or set_grad_enabled require monotonically increasing partitions: "
                 f"highest: {highest_partition}, this node's: {pid}")
            highest_partition = pid

        # do not capture cross-partition dependencies for global state nodes as they will be
        # self-contained - their setup and unwind will be isolated to each partition submodule.
        if node.target not in GLOBAL_STATE_NODES:
            torch.fx.graph.map_arg(
                node.args, lambda def_node: record_cross_partition_use(def_node, node)
            )
            torch.fx.graph.map_arg(
                node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)
            )  # noqa: B950

    original_partition_order = list(partitions.keys())
    # find partitions with no dependencies
    root_partitions: List[str] = []
    for partition_name, partition in partitions.items():
        if not len(partition.dependencies):
            root_partitions.append(partition_name)

    # check partitions for circular dependencies and create topological partition ordering
    sorted_partitions: List[str] = []
    while root_partitions:
        root_partition = root_partitions.pop()
        sorted_partitions.append(root_partition)
        for dependent in partitions[root_partition].dependents:
            partitions[dependent].dependencies.pop(root_partition)
            if not partitions[dependent].dependencies:
                root_partitions.append(dependent)
    if len(sorted_partitions) != len(partitions):
        raise RuntimeError("cycle exists between partitions!")

    # Enter prelude
    for regions_mapping in [autocast_regions, grad_regions]:
        for node, regions in regions_mapping.items():
            assert len(regions) > 0
            partitions[str(regions[0])].environment[node] = node
            for r in regions[1:]:
                partition = partitions[str(r)]
                new_node = partition.graph.create_node(
                    op=node.op,
                    target=node.target,
                    args=tuple(arg for arg in node.args),
                    kwargs={},
                    type_expr=node.type,
                )
                new_node.meta = node.meta.copy()  # is it really a good idea to copy this?
                partition.environment[node] = new_node

    # add placeholders to partition inputs
    for partition_name in sorted_partitions:
        partition = partitions[partition_name]
        for inp in partition.inputs:
            placeholder = partition.graph.placeholder(
                inp,
                type_expr=orig_nodes[inp].type,
            )
            placeholder.meta = orig_nodes[inp].meta.copy()
            partition.environment[orig_nodes[inp]] = placeholder

    # Transform nodes and collect targets for partition's submodule
    for node in m.graph.nodes:
        if hasattr(node, "_fx_partition"):
            partition = partitions[node._fx_partition]

            # swap out old graph nodes in kw/args with references to new nodes in this submodule
            environment = partition.environment
            gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
            gathered_kwargs = torch.fx.graph.map_arg(
                node.kwargs, lambda n: environment[n]
            )

            if node.op not in ["call_module", "get_attr"]:
                target = node.target
            else:
                target_atoms = node.target.split(".")
                target_attr = m
                for atom in target_atoms:
                    if not hasattr(target_attr, atom):
                        raise AttributeError(f"Operator target {node.target} not found!")
                    target_attr = getattr(target_attr, atom)
                # target = target_atoms[-1]
                target = "_".join(target_atoms)
                partition.targets[target] = target_attr
                # Fill in the passed-in mapping from new qualname to old qualname
                if qualname_map is not None:
                    # When creating the split module later, the submodules will have
                    # path prefix matching the corresponding partition's submod_name
                    qualname = f"{partition.submod_name}.{target}"
                    qualname_map[qualname] = node.target

            assert isinstance(gathered_args, tuple)
            assert isinstance(gathered_kwargs, dict)
            name = node.name if keep_original_node_name else None
            new_node = partition.graph.create_node(
                op=node.op,
                target=target,
                args=gathered_args,
                kwargs=gathered_kwargs,
                type_expr=node.type,
                name=name,
            )
            new_node.meta = node.meta.copy()
            partition.environment[node] = new_node

    # Exit epilogue
    for regions_mapping in [autocast_regions]:
        for node in reversed(regions_mapping):
            regions = regions_mapping[node]
            assert len(regions) > 0
            for r in regions[:-1]:
                partition = partitions[str(r)]
                exit_node = autocast_exits[node]
                assert exit_node is not None, "Missing exit node"
                new_node = partition.graph.create_node(
                    op=exit_node.op,
                    target=exit_node.target,
                    args=(partition.environment[node],),
                    kwargs={},
                    type_expr=exit_node.type,
                )
                new_node.meta = exit_node.meta.copy()  # is it really a good idea to copy this?

    # original module environment dict mapping node names to nodes
    orig_mod_env: Dict[str, Node] = {}
    # Set up values to construct base module
    base_mod_env: Dict[str, Node] = {}
    base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
    base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
    if not keep_original_order:
        for node in m.graph.nodes:
            base_mod_env, base_mod_attrs = construct_graph(
                node, base_mod_env, base_mod_attrs
            )

    else:
        # Go through the graph to construct the mapping dict
        for node in m.graph.nodes:
            orig_mod_env[node.name] = node

    # Do some things iterating over the partitions in topological order again:
    # 1) Finish off submodule Graphs by setting corresponding outputs
    # 2) Construct GraphModules for each submodule
    # 3) Construct the base graph by emitting calls to those submodules in
    #    topological order or original order specified by keep_original_order

    construct_order_partitions = (
        sorted_partitions if not keep_original_order else original_partition_order
    )

    already_constructed_attr_nodes = set()

    # We actually need to insert the placeholder nodes in the original order
    # otherwise graph signature will be wrong.
    original_order = [node for node in m.graph.nodes if node.op == "placeholder"]

    for partition_name in construct_order_partitions:
        partition = partitions[partition_name]

        # Set correct output values
        output_vals = tuple(
            partition.environment[orig_nodes[name]] for name in partition.outputs
        )

        # skip output node generation if there are no output values
        num_output_vals = len(output_vals)
        if num_output_vals == 1:
            partition.graph.output(output_vals[0])
        elif num_output_vals > 1:
            partition.graph.output(output_vals)

        if keep_original_order:
            # first get the attr nodes required by this partition
            orig_mod_attr_nodes: List[Node] = [
                orig_mod_env[key] for key in partition.inputs if key not in original_order
            ]

            for node in original_order:
                if node in already_constructed_attr_nodes:
                    continue  # already added this attr to the base graph
                base_mod_env, base_mod_attrs = construct_graph(
                    node, base_mod_env, base_mod_attrs
                )
                already_constructed_attr_nodes.add(node)

            # Construct GraphModule for this partition
            for node in orig_mod_attr_nodes:  # type: ignore[attr-defined]
                if node in already_constructed_attr_nodes:
                    continue
                base_mod_env, base_mod_attrs = construct_graph(
                    node, base_mod_env, base_mod_attrs
                )
                already_constructed_attr_nodes.add(node)

        base_mod_attrs[partition.submod_name] = torch.fx.graph_module.GraphModule(
            partition.targets, partition.graph
        )  # noqa: B950

        # Emit call in base graph to this submodule
        output_val = base_mod_graph.call_module(
            partition.submod_name,
            tuple(base_mod_env[name] for name in partition.inputs),
        )

        num_outputs = len(partition.outputs)
        if num_outputs > 1:
            # Unpack multiple return values from submodule
            output_val_proxy = torch.fx.proxy.Proxy(output_val)
            for i, output_name in enumerate(partition.outputs):
                base_mod_env[output_name] = output_val_proxy[i].node  # type: ignore[index]
        elif num_outputs == 1:
            base_mod_env[next(iter(partition.outputs))] = output_val

    for node in m.graph.nodes:
        if node.op == "output":
            base_mod_graph.output(
                torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])
            )  # noqa: B950

    ret = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)
    log.debug(
        "%s",
        lazy_format_graph_code(
            "post split_module", ret, colored=True
        ),
    )
    return ret