# mypy: allow-untyped-defs
import operator
from collections import deque
from typing import Dict, List, Set, NamedTuple, Tuple, Deque

import torch
from torch.fx.passes.graph_manipulation import get_size_of_all_nodes
from torch.fx.experimental.partitioner_utils import (
    Partition,
    Device,
    PartitionerConfig,
    get_partition_to_latency_mapping,
    get_latency_of_partitioned_graph,
    NodeLatency,
    get_extra_size_of,
    PartitionMode,
)
from torch.fx.graph_module import GraphModule
from torch.fx.node import Node, map_arg
from torch.fx.passes.split_module import split_module


class DAGNode:
    """DAGNode class maintains useful information for a partition (submodule),
    and its input submodules and output submodules.
    """

    def __init__(
        self,
        submodule_node: Node,
        input_nodes: List[Node],
        output_nodes: List[Node],
        logical_device_ids: List[int],
        size_bytes: int,
    ) -> None:
        self.submodule_node: Node = submodule_node
        self.input_nodes: List[Node] = input_nodes
        self.output_nodes: List[Node] = output_nodes
        self.logical_device_ids: List[int] = logical_device_ids
        self.size_bytes = size_bytes

    def __str__(self) -> str:
        return str(self.submodule_node)


class DAG:
    """DAG class contains all the DAG nodes"""

    def __init__(self) -> None:
        self.nodes: List[DAGNode] = []

    def create_node(
        self,
        submodule_node: Node,
        input_nodes: List[Node],
        output_nodes: List[Node],
        logical_devices: List[int],
        size_bytes: int,
    ) -> None:
        """Build a DAGNode from the given fields and append it to self.nodes."""
        node = DAGNode(
            submodule_node, input_nodes, output_nodes, logical_devices, size_bytes
        )
        self.nodes.append(node)


class PartitionResult(NamedTuple):
    """NamedTuple used for returning DAG and a new fx module"""

    dag: DAG
    module_with_submodules: GraphModule


# The following are helper functions for partition manipulation.


def reset_partition_device(partitions):
    """Clear the logical device assignment of every partition in ``partitions``."""
    for partition in partitions:
        partition.logical_device_ids = []


def combine_two_partitions(
    partition_0: Partition, partition_1: Partition, partitions: List[Partition]
) -> None:
    """Given a list of partitions and its two partitions,
    combine these two partitions into a new one appending to the partitions
    and remove the previous two partitions from the list of partitions.

    The list is modified in place; afterwards partition ids, parents and
    children are recomputed for every partition left in the list.
    """
    partition = Partition(len(partitions))
    partition.nodes = partition_0.nodes.union(partition_1.nodes)
    partition.recalculate_mem_size()
    partitions.append(partition)
    partitions.remove(partition_0)
    partitions.remove(partition_1)
    reorganize_partitions(partitions)


def set_parents_and_children(partitions: List[Partition]) -> None:
    """Given a list of partitions, mark parents and children for each partition"""
    # Go through all nodes in a partition.
    # If a node's user is in other partition,
    # then the other partition is this partition's children.
    # This partition is the other partition's parent
    for partition in partitions:
        partition.children = set()
        partition.parents = set()
    for partition in partitions:
        for node in partition.nodes:
            # For each node in the current partition, find its users
            users = node.users
            for n in users:
                # Find which partition the user node belongs to.
                # Note that if the node itself also belongs to that partition,
                # that partition is not the child of the current partition
                for p in partitions:
                    if p != partition and n in p.nodes and node not in p.nodes:
                        partition.children.add(p)
                        p.parents.add(partition)


def reorganize_partitions(partitions: List[Partition]) -> None:
    """Given a list of partitions, reorganize partition id,
    its parents and its children for each partition
    """
    # Rearrange partition ids so that they match the list positions
    for i, partition in enumerate(partitions):
        partition.partition_id = i
    set_parents_and_children(partitions)


def get_bfs_level_partition(partitions: List[Partition]) -> None:
    """Given a list of partitions, mark the bfs level for each partition.

    Partitions without parents form level 0. A partition reachable through
    paths of different lengths is processed once per level it appears in, so
    it ends up with the deepest level. The traversal keeps no visited set, so
    the parent/child relation is expected to be acyclic.
    """
    # Partitions without a parent are the roots
    current_level: Set[Partition] = {
        partition for partition in partitions if len(partition.parents) == 0
    }
    next_level: Set[Partition] = set()
    level = 0
    while current_level:
        partition = current_level.pop()
        partition.bfs_level = level
        next_level.update(partition.children)
        if not current_level:
            # The whole level has been consumed; move on to the next one
            current_level = next_level.copy()
            next_level = set()
            level += 1


def get_node_to_partition_mapping(partitions: List[Partition]) -> Dict[Node, int]:
    """Given a list of partitions, return node to partition id mapping"""
    node_to_partition: Dict[Node, int] = {}
    for partition in partitions:
        for node in partition.nodes:
            node_to_partition[node] = partition.partition_id
    return node_to_partition


def get_logical_id_to_device(devices: List[Device]) -> Dict[int, Device]:
    """Get a mapping from device logical ID to Device object."""
    logical_id_to_device: Dict[int, Device] = {}
    for d in devices:
        logical_id_to_device[d.logical_id] = d
    return logical_id_to_device


def get_device_partition_stats(
    partitions: List[Partition], devices: List[Device]
) -> Tuple[Dict[Device, List[Partition]], Dict[Device, int], List[Partition]]:
    """Given a list of partitions and a list of devices, returns:
    1. A mapping from device to partitions on it;
    2. A mapping from device to its remaining memory size;
    3. A list of partitions that do not have a device.
    """
    # logical id to device
    logical_id_to_device = get_logical_id_to_device(devices)
    # Track partitions on device
    device_to_partitions: Dict[Device, List[Partition]] = {}
    # Track device's left mem size
    device_to_left_mem_bytes: Dict[Device, int] = {}
    for d in devices:
        device_to_partitions[d] = []
        device_to_left_mem_bytes[d] = d.available_mem_bytes

    # Deal with the partitions that already have a device
    # and also collect all partitions without a device (no_device_partitions)
    no_device_partitions = []
    for partition in partitions:
        if partition.logical_device_ids != []:
            for logical_id in partition.logical_device_ids:
                device = logical_id_to_device[logical_id]
                device_to_partitions[device].append(partition)
                device_to_left_mem_bytes[device] -= partition.used_mem_bytes
        else:
            no_device_partitions.append(partition)

    return (
        device_to_partitions,
        device_to_left_mem_bytes,
        no_device_partitions,
    )


def get_device_to_partitions_mapping(
    partitions: List[Partition], devices: List[Device]
):
    """Given a list of partitions and a list of devices,
    map each partition into a device.

    Partitions that already carry logical device ids are kept where they are;
    every other partition gets a logical device id appended in place.
    Returns True when every partition found a device, False as soon as one
    partition cannot be placed.
    """

    def calculate_extra_mem_bytes_needed_for(
        partition: Partition, partitions: List[Partition]
    ):
        """Return the extra bytes needed to add ``partition`` next to
        ``partitions`` (the partitions already placed on one device).
        """
        all_nodes: Set[Node] = set()
        for p in partitions:
            all_nodes = all_nodes.union(p.nodes)
        if len(all_nodes) == 0:
            return partition.used_mem_bytes
        all_nodes = all_nodes.union(partition.nodes)
        extra_size_needed = 0
        for node in partition.nodes:
            extra_size_needed += get_extra_size_of(node, all_nodes)
        return extra_size_needed

    def find_device_for(partition: Partition):
        """Given a partition, find a logical device for the partition
        The algorithm is to put the partition on the device
        that has just enough mem left for that partition.
        device_to_left_mem_bytes is a dictionary between device and its left mem size
        sorted by its left mem size
        """
        for d in device_to_left_mem_bytes:
            extra_size_needed = calculate_extra_mem_bytes_needed_for(
                partition, device_to_partitions[d]
            )
            # Strict comparison: a device whose remaining memory equals the
            # required size is not selected.
            if extra_size_needed < device_to_left_mem_bytes[d]:
                device_to_partitions[d].append(partition)
                partition.logical_device_ids.append(d.logical_id)
                device_to_left_mem_bytes[d] -= extra_size_needed
                return True
        return False

    (
        device_to_partitions,
        device_to_left_mem_bytes,
        no_device_partitions,
    ) = get_device_partition_stats(partitions, devices)

    # Find devices for all the partitions without a device
    found_device = True
    for partition in no_device_partitions:
        # Re-sort by remaining memory (ascending) so the tightest fit is tried first
        device_to_left_mem_bytes = dict(
            sorted(device_to_left_mem_bytes.items(), key=operator.itemgetter(1))
        )
        found_device = find_device_for(partition)
        if not found_device:
            break
    return found_device


def check_dependency(partition):
    """Given a partition, check if there is a circular dependency on
    this partition using bfs.

    Returns True when ``partition`` can be reached again by following
    ``children`` links starting from itself, False otherwise.
    """
    visited: Set[Partition] = {partition}
    queue: Deque[Partition] = deque([partition])
    while queue:
        p = queue.popleft()
        for child in p.children:
            if child == partition:
                return True
            if child not in visited:
                visited.add(child)
                queue.append(child)
    return False


class Partitioner:
    """A fx module may not fit into one device.
    Partitioner class helps partition one fx module into submodules (partitions),
    so that the submodules can be executed crossing different accelerators.
    The main function of this class is self.partition_graph.
    It partitions the fx module based on the scheme specified in partition_config
    A DAG structure is returned
    along with a new fx module with submodule nodes.
    """

    def __init__(self) -> None:
        self.partitions: List[Partition] = []
        self.node_to_partition: Dict[Node, int] = {}
        self.devices: List[Device] = []

    def partition_graph(
        self,
        fx_module: GraphModule,
        torch_module: torch.nn.Module,
        partitioner_config: PartitionerConfig,
    ) -> PartitionResult:
        """Given the fx module, torch module and partitioner_config,
        find the partitions, do the partitions,
        and then return a DAG and a new fx module with submodule nodes (partitions)

        Raises RuntimeError when no devices are configured, when the module
        has no operation nodes, or when the devices together do not have
        enough memory for the module.
        """
        self.graph_module = fx_module
        self.torch_module = torch_module
        self.devices = partitioner_config.devices
        if len(self.devices) == 0:
            raise RuntimeError("No devices")
        # Tag the size in bytes to all nodes in the graph_module.
        get_size_of_all_nodes(self.graph_module)
        # Check if there are op nodes in the fx module
        nodes = self.graph_module.graph.nodes
        if all(node.op in {"placeholder", "get_attr", "output"} for node in nodes):
            raise RuntimeError("No Partition since no operations in the module")
        # Calculate total size of the fx module
        total_size_of_graph = 0
        for node in nodes:
            if node.op == "output":
                break
            total_size_of_graph += node.size_bytes.total_size
        # Find the device with the max mem size
        device_with_max_mem = max(self.devices, key=lambda d: d.available_mem_bytes)
        # AOT based partition
        if partitioner_config.mode == PartitionMode.aot_based:
            self.aot_based_partition(
                partitioner_config.node_to_partition_mapping,
                partitioner_config.partition_to_logical_device_mapping,
            )
        # Single partition if the whole module can be fit into one device
        elif total_size_of_graph <= device_with_max_mem.available_mem_bytes:
            self.find_single_partition(
                total_size_of_graph, logical_device_id=device_with_max_mem.logical_id
            )
        elif total_size_of_graph > sum(d.available_mem_bytes for d in self.devices):
            raise RuntimeError("Devices have no enough memory for the module")
        else:
            # Sparse nn based partition
            if partitioner_config.mode == PartitionMode.sparse_nn:
                available_mem_bytes = self.devices[0].available_mem_bytes
                if not all(
                    device.available_mem_bytes == available_mem_bytes
                    for device in self.devices
                ):
                    raise RuntimeError("All devices must have same memory size!")
                # sparse_nn_partition only support same memory size
                # TODO: add different size support for sparse_nn_partition
                self.sparse_nn_partition(available_mem_bytes)
            # Cost aware partition
            elif partitioner_config.mode == PartitionMode.cost_aware:
                self.cost_aware_partition(
                    partitioner_config.transfer_rate_bytes_per_sec,
                    partitioner_config.node_to_latency_mapping,
                )
            # KL based partition
            elif partitioner_config.mode == PartitionMode.kl_based:
                self.kl_based_partition(
                    partitioner_config.transfer_rate_bytes_per_sec,
                    partitioner_config.node_to_latency_mapping,
                )
            else:
                self.size_based_partition()

        # Saturate host if possible.
        if partitioner_config.saturate_host:
            self.saturate_host()

        # Partition the graph module based on the partition assignment.
        module_with_submodules = self.do_partition()

        # The DAG contains DAGNodes with info of each partition's input nodes, output nodes
        # and how partitions are connected.
        dag = self.dump_dag(module_with_submodules)
        ret = PartitionResult(dag, module_with_submodules)
        return ret

    def find_single_partition(
        self, total_size_of_graph, logical_device_id: int = 0
    ) -> None:
        """Fit the whole fx module into one device"""
        partition_0 = self.create_partition()
        for node in self.graph_module.graph.nodes:
            if node.op == "output":
                # Skip the output node, but there can
                # be nodes after the output in certain cases.
                continue
            partition_0.nodes.add(node)
        partition_0.used_mem_bytes = total_size_of_graph
        partition_0.logical_device_ids = [logical_device_id]
        # Get the node to partition mapping
        self.node_to_partition = get_node_to_partition_mapping(self.partitions)

    def size_based_partition(self) -> None:
        """This method is to partition the fx module based on memory size.
        It uses greedy approach. The result may not be the best.
        The basic idea is:
        Step 1:
        Find a device which has enough memory to fit the current node, create a empty partition
        with the size of that device.
        Then keep adding the following nodes into the partition until the partition is full.
        Step 2:
        Repeat Step 1 until no device left
        Step 3:
        If some nodes are left, create a partition for each left node (single node partition).
        and then try to map those partitions into logical devices with enough mem left.

        Raises RuntimeError when a node fits no unoccupied device or when the
        resulting partitions cannot all be mapped to logical devices.
        """

        def find_device_based_on_size(node) -> Device:
            """Given a node, this function is to find a logical device
            that could fit the node. The returned device is recorded in
            occupied_devices.
            """
            mem_size_needed = get_extra_size_of(node, set())
            # Sentinel with negative memory marks "nothing found"
            device = Device("", -1, -1)
            for d in self.devices:
                if (
                    d not in occupied_devices
                    and d.available_mem_bytes >= mem_size_needed
                ):
                    device = d
                    break
            if device.available_mem_bytes < 0:
                raise RuntimeError(str(node) + " is too large to fit any device")
            occupied_devices.append(device)
            return device

        # Track partition and its left mem size
        partition_to_left_mem_bytes: Dict[Partition, int] = {}
        # Track all the devices that have been used
        occupied_devices: List[Device] = []
        partition = self.create_partition()
        for node in self.graph_module.graph.nodes:
            if node.op in {"call_module", "call_method", "call_function"}:
                # Check if there are devices left
                if len(self.partitions) <= len(self.devices):
                    total_size_of_input_nodes = get_extra_size_of(node, partition.nodes)
                    # Check if the current partition is the very first partition
                    if partition.used_mem_bytes == 0:
                        # Find a device to fit the first node
                        # (the helper already records it as occupied)
                        device = find_device_based_on_size(node)
                        # Update partition and its left mem size
                        partition_to_left_mem_bytes[
                            partition
                        ] = device.available_mem_bytes
                        # Update available mem for the current partition
                        partition.logical_device_ids.append(device.logical_id)
                    else:
                        # The current partition is not the first partition
                        # Check if the current node can fit into current partition
                        if (
                            partition_to_left_mem_bytes[partition]
                            < total_size_of_input_nodes
                        ):
                            # Check if no device is left
                            if len(self.partitions) == len(self.devices):
                                # No device is left
                                # Create the first single node partition for the current node
                                self.create_single_node_partition(node)
                                continue
                            # Some devices are still left
                            # Create a new partition with a mem size that is enough for the current node
                            device = find_device_based_on_size(node)
                            partition = self.create_partition()
                            total_size_of_input_nodes = get_extra_size_of(
                                node, partition.nodes
                            )
                            partition_to_left_mem_bytes[
                                partition
                            ] = device.available_mem_bytes
                            partition.logical_device_ids.append(device.logical_id)
                    partition.add_node(node)
                    partition_to_left_mem_bytes[partition] -= total_size_of_input_nodes
                # Create single node partitions if no device is left
                else:
                    self.create_single_node_partition(node)
        reorganize_partitions(self.partitions)
        # Get the node to partition mapping
        self.node_to_partition = get_node_to_partition_mapping(self.partitions)
        # Mapping all partitions into device
        found_partition_to_device_mapping = get_device_to_partitions_mapping(
            self.partitions, self.devices
        )
        if not found_partition_to_device_mapping:
            raise RuntimeError("Cannot Get a Valid Partition to Logical Device Mapping")

    def saturate_host(self) -> None:
        """Saturate host by assigning replicates to unused devices with enough memory.
        It uses a greedy approach to find a next available set of devices to place all split
        partitions: For each used device, it searches for an idle device with minimal memory
        size that can hold all the partition located on that device; If the search is successful
        for all used devices, it then assigns the new devices' logical ID to the corresponding
        partition.
        """
        (
            device_to_partitions,
            device_to_left_mem_bytes,
            no_device_partitions,
        ) = get_device_partition_stats(self.partitions, self.devices)

        assert (
            len(no_device_partitions) == 0
        ), f"Expect no_device_partitions has 0 device, but get {len(no_device_partitions)}"

        # Devices that hold partitions
        used_devices = [d for d in self.devices if len(device_to_partitions[d]) > 0]
        # Track replicates of the assigned devices
        replicated_device_to_used_device: Dict[Device, Device] = {}

        # Keep going while there are enough devices left for one more full replica set
        while len(used_devices) * 2 + len(replicated_device_to_used_device) <= len(
            self.devices
        ):
            # Success flag for this round
            success = True
            # Devices that have not been assigned
            idle_devices = [
                d
                for d in self.devices
                if d not in used_devices and d not in replicated_device_to_used_device
            ]
            # Temporary mapping from replicated device to original device
            temp_replicate_mapping = {}

            # Find a new device to replicate all partitions on an used device
            for used_device in used_devices:
                # Idle devices that have enough memory
                available_devices = [
                    d
                    for d in idle_devices
                    if d.available_mem_bytes
                    >= used_device.available_mem_bytes
                    - device_to_left_mem_bytes[used_device]
                ]
                if len(available_devices) == 0:
                    success = False
                    break
                new_device = min(available_devices, key=lambda d: d.available_mem_bytes)
                idle_devices.remove(new_device)
                temp_replicate_mapping[new_device] = used_device

            # Only commit a round in which every used device got a replica
            if not success:
                break
            replicated_device_to_used_device.update(temp_replicate_mapping)

        # Update logical device IDs assigned to the partitions
        for (
            replicate_device,
            original_device,
        ) in replicated_device_to_used_device.items():
            logical_id = replicate_device.logical_id
            for partition in device_to_partitions[original_device]:
                partition.logical_device_ids.append(logical_id)

    def do_partition(self) -> GraphModule:
        """Return a new fx module with submodule nodes (partitions)."""
        module_with_submodules = split_module(
            self.graph_module,
            self.torch_module,
            lambda node: self.node_to_partition[node],
        )
        return module_with_submodules

    def dump_dag(self, module_with_submodules: GraphModule) -> DAG:
        """Return the dag structure built from the new fx module with submodules."""
        dag = DAG()
        for node in module_with_submodules.graph.nodes:
            if node.op == "output":
                break
            if node.op in {"placeholder", "get_attr"}:
                continue
            if node.target == operator.__getitem__:
                continue
            # A dict is used as an insertion-ordered set of the input nodes
            input_nodes: Dict[Node, None] = {}
            map_arg(node.args, input_nodes.setdefault)
            map_arg(node.kwargs, input_nodes.setdefault)
            # When a node has two or more output nodes,
            # it outputs its result to 'getitem' nodes.
            # Those 'getitem' nodes are the output node for this node.
            # Otherwise, the output node is this node itself.
            if len(node.users) > 1:
                output_nodes = list(node.users)
            else:
                output_nodes = [node]
            # The partition id is parsed from the trailing "_<id>" of the node name
            partition_id = int(node.name.rsplit("_", 1)[-1])
            device_ids = self.partitions[partition_id].logical_device_ids
            size_bytes = self.partitions[partition_id].used_mem_bytes
            dag.create_node(
                node, list(input_nodes), output_nodes, device_ids, size_bytes
            )
        return dag

    def create_partition(self) -> Partition:
        """Create a partition and append it to self.partitions."""
        partition_id = len(self.partitions)
        partition = Partition(partition_id)
        self.partitions.append(partition)
        return partition

    def create_single_node_partition(self, node):
        """Create a partition for a single node"""
        partition = self.create_partition()
        partition.add_node(node)

    def sparse_nn_partition(self, available_mem_bytes: int) -> None:
        """This method partition a sparse nn module.
        It is size based partition but different from size_based_partition,
        it only works when all the devices have same memory size (available_mem_bytes).
        In the future, devices with different mem sizes will be supported like size_based_partition.
        It first traverse all the nodes and do the partitions based on the same memory size.
        If the current partition has no enough memory left for a new op node
        (call_module, call_method, call_function), a new partition is created.
        When crossing the boundary between non-embedding nodes and embedding nodes,
        a new partition is created regardlessly.
        For example, if the current node is a non-embedding node but the next node is an
        embedding node, a new partition is created for the next node.
        After the partition, the partitions are combined as much as possible.
        The rule is that a non-embedding partition only
        combines with another non-embedding one.
        So as the embedding partitions.

        Raises RuntimeError when a single node does not fit into a device,
        when there are more embedding partitions than devices, or when an
        embedding partition plus all non-embedding partitions exceed
        available_mem_bytes.
        """

        def combine_partitions_based_on_size(
            partitions: List[Partition], available_mem_bytes: int
        ) -> None:
            """Combining small partitions together to keep as less partitions as possible.
            Here is an example of the algorithm to do this:
            Assume some partitions, we first sort them based on partition used memory size.
            [(partition_4, 1), (partition_3, 1), (partition_2, 2), (partition_1, 7), (partition_0, 9)]
            The available memory is 10.
            step 1: self.find_partition_to_combine_based_on_size()
            First, mark bfs level for each partition
            Second, look the smallest partition, partition_4: 10 - 1 = 9
            It means any partition has a used memory equal or less than 9 could combine this partition
            We go from the largest and selection partition_0.
            Check the bfs level for two partitions, if the level difference is less than 2,
            it can be combined.
            step 2: repeat step 1 until no partitions can be combined
            """
            find_combination = True
            while find_combination:
                # Sort partitions based on memory size
                sorted_partitions = sorted(partitions, key=lambda p: p.used_mem_bytes)
                # Mark bfs level
                get_bfs_level_partition(self.partitions)
                find_combination, partitions = find_partition_to_combine_based_on_size(
                    sorted_partitions, available_mem_bytes, partitions
                )

        def calculate_mem_bytes_needed(p1, p2):
            """Given two partitions, calculate how many mem bytes
            are needed if two partitions are combined
            """
            nodes = p1.nodes.union(p2.nodes)
            mem_bytes_needed = 0
            for node in nodes:
                mem_bytes_needed += get_extra_size_of(node, nodes)
            return mem_bytes_needed

        def find_partition_to_combine_based_on_size(
            sorted_partitions: List[Partition],
            available_mem_bytes: int,
            partitions: List[Partition],
        ) -> Tuple[bool, List[Partition]]:
            """step 1 in combine_partition_based_on_size()"""
            find_combination = False
            # Nothing to combine when the group is empty (e.g. a graph without
            # embedding nodes); popping from an empty list would raise.
            if not sorted_partitions:
                return find_combination, partitions
            smallest_partition = sorted_partitions.pop(0)
            # Try the largest candidates first
            for p in sorted_partitions[::-1]:
                if abs(smallest_partition.bfs_level - p.bfs_level) <= 1:
                    # Calculate how many bytes needed if combined
                    mem_bytes_needed = calculate_mem_bytes_needed(p, smallest_partition)
                    if mem_bytes_needed <= available_mem_bytes:
                        combine_two_partitions(p, smallest_partition, self.partitions)
                        partitions.remove(smallest_partition)
                        partitions.remove(p)
                        # combine_two_partitions appended the merged partition last
                        partitions.append(self.partitions[-1])
                        find_combination = True
                        break
            return find_combination, partitions

        def reset_partition_in_sparse_nn(partition, new_partition=True):
            """If crossing the boundary between non-embedding nodes and
            embedding nodes, create a new partition.
            The given partition is filed under the current region; a fresh
            partition is returned when new_partition is True, otherwise None.
            """
            if in_embedding_region:
                embedding_partitions.append(partition)
            else:
                non_embedding_partitions.append(partition)
            if new_partition:
                partition = self.create_partition()
                partition.left_mem_bytes = available_mem_bytes
                return partition
            return None

        def is_embedding_node(node: Node) -> bool:
            """Check if a node is an embedding node"""
            if node.op == "call_module":
                submodule = self.graph_module
                # Walk the dotted target path down to the referenced submodule
                for atom in str(node.target).split("."):
                    if not hasattr(submodule, atom):
                        raise RuntimeError(
                            f"Module {submodule} has no attribute {atom}"
                        )
                    submodule = getattr(submodule, atom)
                    if "Embedding" in str(submodule):
                        return True
            return False

        # Track embedding partitions and non-embedding partitions separately
        embedding_partitions: List[Partition] = []
        non_embedding_partitions: List[Partition] = []
        # A Flag to check the boundary
        in_embedding_region: bool = False
        partition = self.create_partition()
        for node in self.graph_module.graph.nodes:
            if node.op in {"call_module", "call_method", "call_function"}:
                # Check if crossing the boundary between embedding nodes and non embedding nodes
                if is_embedding_node(node) != in_embedding_region:
                    # Crossing the boundary
                    # Check if the current partition is an empty partition
                    if partition.used_mem_bytes != 0:
                        # The current partition isn't an empty partition. Create a new one.
                        partition = reset_partition_in_sparse_nn(partition)
                    in_embedding_region = not in_embedding_region
                total_size_of_input_nodes = get_extra_size_of(node, partition.nodes)
                if (
                    total_size_of_input_nodes + partition.used_mem_bytes
                    > available_mem_bytes
                ):
                    partition = reset_partition_in_sparse_nn(partition)
                    total_size_of_input_nodes = get_extra_size_of(node, partition.nodes)
                    if total_size_of_input_nodes > available_mem_bytes:
                        # node.target may be a callable, so format it instead
                        # of concatenating it with a str
                        raise RuntimeError(
                            f"{node.target} is too large to fit into a device"
                        )
                partition.add_node(node)
        reset_partition_in_sparse_nn(partition, new_partition=False)
        # Set parents and children for partitions
        set_parents_and_children(self.partitions)
        # Combining non-embedding partitions
        combine_partitions_based_on_size(non_embedding_partitions, available_mem_bytes)
        # Combining embedding partitions
        combine_partitions_based_on_size(embedding_partitions, available_mem_bytes)
        total_size_of_non_embedding_partitions = 0
        for partition in non_embedding_partitions:
            total_size_of_non_embedding_partitions += partition.used_mem_bytes
        # Check if devices are enough for all partitions
        if len(embedding_partitions) > len(self.devices):
            msg = (
                "Need "
                + str(len(embedding_partitions))
                + " devices, but only "
                + str(len(self.devices))
                + " provided"
            )
            raise RuntimeError(msg)
        occupied_devices = []
        for i, partition in enumerate(embedding_partitions):
            # Check if all non-embedding partitions can fit into embedding partition devices
            if (
                total_size_of_non_embedding_partitions + partition.used_mem_bytes
                > available_mem_bytes
            ):
                raise RuntimeError(
                    "partition_"
                    + str(partition.partition_id)
                    + "(embedding partition) and non embedding partitions can not fit into one device"
                )
            else:
                # Add logical device to the partition
                partition.logical_device_ids = [self.devices[i].logical_id]
                occupied_devices.append(self.devices[i].logical_id)
        # Add logical devices to the non_embedding_partitions
        for partition in non_embedding_partitions:
            # Give each partition its own list so that a later in-place append
            # on one partition does not leak into the others
            partition.logical_device_ids = list(occupied_devices)
        # Get the node to partition mapping
        self.node_to_partition = get_node_to_partition_mapping(self.partitions)

    def cost_aware_partition(
        self,
        transfer_rate_bytes_per_sec: float,
        node_to_latency_mapping: Dict[Node, NodeLatency],
    ) -> None:
        """This method is to partition the fx module based on the cost.
        The cost is the total latency of running the whole fx module.
        In partitioner_utils.py, the cost model is built.
        The cost aware partition algorithm is:
        #1. At every beginning, each node is a partition.
            Then we map all the partitions to the devices
            and calculate the cost
        #2. Then try to pre-combine any two of the partitions if the two
            partitions can be combined.
            (the bfs level is less than 2 or two partitions are connected and
            can find partition to device mapping)
            See if any partition pair could reduce the current cost.
            Choose the pair that shows the minimum cost and then combine them
        #3. Repeat #2 until the cost cannot be reduced.
        """

        def try_combining_partitions(p0_index, p1_index, partitions) -> float:
            """Given two partitions and a list of partitions, combine these two partitions
            and see what is the cost of the modified partition list
            """
            p0 = partitions[p0_index]
            p1 = partitions[p1_index]
            # If two partitions' bfs level are less than 2 or two partitions
            # are connected to each other, then they can be combined
            if (
                (abs(p0.bfs_level - p1.bfs_level) <= 1)
                or (p0 in p1.parents)
                or p0 in (p1.children)
            ):
                combine_two_partitions(p0, p1, partitions)
                # Check if a circular dependency exists after combining
                if check_dependency(partitions[-1]):
                    return float("inf")
                # Check if the modified partition list can be mapped to devices after combination
                reset_partition_device(partitions)
                found_device = get_device_to_partitions_mapping(
                    partitions, self.devices
                )
                if not found_device:
                    return float("inf")
                # Calculate the new cost
                partition_to_latency_mapping = get_partition_to_latency_mapping(
                    partitions, node_to_latency_mapping
                )
                cost = get_latency_of_partitioned_graph(
                    partitions,
                    partition_to_latency_mapping,
                    transfer_rate_bytes_per_sec,
                )
                return cost
            # If two partition can not be combined, the cost is inf
            return float("inf")

        def search_combination(
            transfer_rate_bytes_per_sec, node_to_latency_mapping
        ) -> bool:
            """Given transfer rate between partitions and each node's latency,
            find two partitions to combine so the cost of the partitions can
            be reduced.
            The algorithm is :
            1. Go through all the partition pairs and see
            if any pair of partitions can be combined.
            2. Calculate the cost after the combination.
            3. Select the minimum cost and combine its corresponding partition pair.
            """
            partition_to_latency_mapping = get_partition_to_latency_mapping(
                self.partitions, node_to_latency_mapping
            )
            cost = get_latency_of_partitioned_graph(
                self.partitions,
                partition_to_latency_mapping,
                transfer_rate_bytes_per_sec,
            )
            if len(self.partitions) == 1:
                return False
            partition_pair: List[int] = []
            for i in range(len(self.partitions) - 1):
                for j in range(i + 1, len(self.partitions)):
                    # Try to combine the partition pair
                    # and see the new cost after combination
                    new_cost = try_combining_partitions(i, j, self.partitions[:])
                    if new_cost <= cost:
                        partition_pair = [i, j]
                        cost = new_cost
                    # The trial ran on a copy of the list but touched the shared
                    # Partition objects; restore ids, parents and children
                    reorganize_partitions(self.partitions)
            # If a partition pair is found, combine them
            if len(partition_pair) != 0:
                p0 = self.partitions[partition_pair[0]]
                p1 = self.partitions[partition_pair[1]]
                combine_two_partitions(p0, p1, self.partitions)
            get_bfs_level_partition(self.partitions)
            reset_partition_device(self.partitions)
            get_device_to_partitions_mapping(self.partitions, self.devices)
            return len(partition_pair) != 0

        for node in self.graph_module.graph.nodes:
            if node.op not in {"placeholder", "get_attr", "output"}:
                self.create_single_node_partition(node)
        # Set up parent partitions and children partitions for each partition
        set_parents_and_children(self.partitions)
        # Get bfs level for each partition
        get_bfs_level_partition(self.partitions)
        find_combination = True
        while find_combination:
            # Search for a pair partition to generate the minimum new cost,
            # then combine them
            find_combination = search_combination(
                transfer_rate_bytes_per_sec, node_to_latency_mapping
            )
        # Make sure all partitions are set up correctly
        reorganize_partitions(self.partitions)
        # Set up node to partition mapping
        self.node_to_partition = get_node_to_partition_mapping(self.partitions)
916 self, 917 transfer_rate_bytes_per_sec: float, 918 node_to_latency_mapping: Dict[Node, NodeLatency], 919 ) -> None: 920 """This function is a cost aware partition based 921 on Kernighan-Lin algorithm. 922 First, the graph is partitioned using size_based_partition. 923 Then, each node is swapped with any other node in a different 924 partition, and at the same time, the cost is estimated after 925 the swapping. 926 For example, we have nodes n0, n1, n2, n3 and n4. 927 Using size_based_partition, n0 and n1 are in Partition p0. 928 n2, n3 and n4 in Partition p1. The current cost is estimated. 929 We first tried using n0 to swap with n2 from the other partition. 930 Then we see that swapping n0 and n2 shows a lower cost 931 than the current cost and it is the minimum among other pairs like 932 (n0, None)(This means moving n0 to Partition without swapping other nodes), 933 (n0, n3) and (n0, n4). We swap n0 and n2 and set the new cost 934 as the current cost. 935 Then We repeat this process for all the other nodes until all swapping pairs 936 are tried. 
937 """ 938 939 def swap_nodes(n0, n1, p0, p1): 940 # Either n0 or n1 could be None 941 # That means we simply move the node 942 # to another partition 943 if n0 is not None: 944 p0.remove_node(n0) 945 p1.add_node(n0) 946 if n1 is not None: 947 p0.add_node(n1) 948 p1.remove_node(n1) 949 950 def try_swap_nodes( 951 n0, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec 952 ): 953 cost = float("inf") 954 swap_nodes(n0, n1, p0, p1) 955 # Reorganize partitions after swapping 956 reorganize_partitions(self.partitions) 957 # Check if there is a circular dependency after swapping 958 if (not check_dependency(p0)) and (not check_dependency(p1)): 959 reset_partition_device(self.partitions) 960 partition_to_latency_mapping = get_partition_to_latency_mapping( 961 self.partitions, node_to_latency_mapping 962 ) 963 # Check if all partitions can be mapped to logical devices after swapping 964 found_device = get_device_to_partitions_mapping( 965 self.partitions, self.devices 966 ) 967 if not found_device: 968 cost = float("inf") 969 else: 970 cost = get_latency_of_partitioned_graph( 971 self.partitions, 972 partition_to_latency_mapping, 973 transfer_rate_bytes_per_sec, 974 ) 975 # Swap back and reset all partitions back to original 976 swap_nodes(n1, n0, p0, p1) 977 reorganize_partitions(self.partitions) 978 reset_partition_device(self.partitions) 979 get_device_to_partitions_mapping(self.partitions, self.devices) 980 return cost 981 982 def swap_node_to_partition( 983 node, p0, p1, node_to_latency_mapping, transfer_rate_per_sec 984 ): 985 """This function helps to swap one node from partition p0 986 with all the nodes in another partition p1 987 """ 988 p1_nodes = list(p1.nodes) + [None] 989 min_cost = float("inf") 990 node_pair: List[Node] = [] 991 for n1 in p1_nodes: 992 # Ignore the node if it is not a op node 993 if n1 is not None and n1.op in {"placeholder", "get_attr"}: 994 continue 995 # Try swapping node in p0 with n1 in p1 996 cost = try_swap_nodes( 997 node, 
n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec 998 ) 999 if cost < min_cost: 1000 node_pair = [node, n1] 1001 min_cost = cost 1002 return cost, node_pair # type: ignore[possibly-undefined] 1003 1004 # First use size_base_partition 1005 self.size_based_partition() 1006 partition_to_latency_mapping = get_partition_to_latency_mapping( 1007 self.partitions, node_to_latency_mapping 1008 ) 1009 # Calculate the cost of the partitions 1010 cost = get_latency_of_partitioned_graph( 1011 self.partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec 1012 ) 1013 # Keep tracking the node pair that shows the better cost 1014 node_pair: List[Node] = [] 1015 # Keep tracking the partition pair of node pair 1016 partition_pair: List[Partition] = [] 1017 # Collect all the op nodes from the graph 1018 op_nodes = [] 1019 for n in self.graph_module.graph.nodes: 1020 if n.op not in {"placeholder", "get_attr", "output"}: 1021 op_nodes.append(n) 1022 for node in op_nodes: 1023 # Find which partition the current node belongs 1024 p0_index = self.node_to_partition[node] 1025 p0 = self.partitions[p0_index] 1026 # Go through all the other partitions to swap 1027 # with other nodes from those partitions 1028 for p1_index, _ in enumerate(self.partitions): 1029 if p0_index != p1_index: 1030 p1 = self.partitions[p1_index] 1031 new_cost, new_node_pair = swap_node_to_partition( 1032 node, 1033 p0, 1034 p1, 1035 node_to_latency_mapping, 1036 transfer_rate_bytes_per_sec, 1037 ) 1038 # Update the cost 1039 # Track the swapped node pair and their partitions 1040 if new_cost < cost: 1041 cost = new_cost 1042 node_pair = new_node_pair 1043 partition_pair = [p0, p1] 1044 # Do the swapping after trying all the nodes from a partition 1045 if len(node_pair) != 0: 1046 swap_nodes( 1047 node_pair[0], node_pair[1], partition_pair[0], partition_pair[1] 1048 ) 1049 reorganize_partitions(self.partitions) 1050 get_device_to_partitions_mapping(self.partitions, self.devices) 1051 
reorganize_partitions(self.partitions) 1052 # Mapping the device to the partition 1053 get_device_to_partitions_mapping(self.partitions, self.devices) 1054 return 1055 1056 def aot_based_partition( 1057 self, node_to_partition_mapping, partition_to_logical_device_mapping 1058 ): 1059 """This function helps to rebuild the partitions given the nodes and its 1060 corresponding partition id 1061 """ 1062 partition_id_to_partition_mapping: Dict[int, Partition] = {} 1063 self.node_to_partition = node_to_partition_mapping 1064 for node in self.node_to_partition: 1065 partition_id = self.node_to_partition[node] 1066 # If the requested partition has not been created, create the partition 1067 if partition_id not in partition_id_to_partition_mapping: 1068 partition = Partition(partition_id) 1069 self.partitions.append(partition) 1070 partition_id_to_partition_mapping[partition_id] = partition 1071 partition.logical_device_ids = partition_to_logical_device_mapping[ 1072 partition_id 1073 ] 1074 else: 1075 partition = partition_id_to_partition_mapping[ 1076 self.node_to_partition[node] 1077 ] 1078 # Add the current node into the partition 1079 partition.add_node(node) 1080