xref: /aosp_15_r20/external/pytorch/torch/fx/experimental/accelerator_partitioner.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import operator
3from collections import deque
4from typing import Dict, List, Set, NamedTuple, Tuple, Deque
5
6import torch
7from torch.fx.passes.graph_manipulation import get_size_of_all_nodes
8from torch.fx.experimental.partitioner_utils import (
9    Partition,
10    Device,
11    PartitionerConfig,
12    get_partition_to_latency_mapping,
13    get_latency_of_partitioned_graph,
14    NodeLatency,
15    get_extra_size_of,
16    PartitionMode,
17)
18from torch.fx.graph_module import GraphModule
19from torch.fx.node import Node, map_arg
20from torch.fx.passes.split_module import split_module
21
22
class DAGNode:
    """One vertex of the partition DAG.

    Bundles the submodule call node of a partition with the nodes feeding it,
    the nodes carrying its results, the logical devices it is assigned to and
    its size in bytes.
    """

    def __init__(
        self,
        submodule_node: Node,
        input_nodes: List[Node],
        output_nodes: List[Node],
        logical_device_ids: List[int],
        size_bytes: int,
    ) -> None:
        # All arguments are stored as given; nothing is copied or validated.
        self.submodule_node: Node = submodule_node
        self.input_nodes: List[Node] = input_nodes
        self.output_nodes: List[Node] = output_nodes
        self.logical_device_ids: List[int] = logical_device_ids
        self.size_bytes: int = size_bytes

    def __str__(self) -> str:
        """Render this DAG node as its underlying submodule node."""
        return str(self.submodule_node)
44
45
class DAG:
    """Ordered collection of the DAGNode objects of a partitioned module."""

    def __init__(self) -> None:
        # DAG nodes, appended in the order create_node is called.
        self.nodes: List[DAGNode] = []

    def create_node(
        self,
        submodule_node: Node,
        input_nodes: List[Node],
        output_nodes: List[Node],
        logical_devices: List[int],
        size_bytes: int,
    ) -> None:
        """Wrap the given partition info in a new DAGNode and append it to ``self.nodes``."""
        self.nodes.append(
            DAGNode(
                submodule_node,
                input_nodes,
                output_nodes,
                logical_devices,
                size_bytes,
            )
        )
64
65
class PartitionResult(NamedTuple):
    """NamedTuple returned by ``Partitioner.partition_graph``.

    Pairs the DAG that describes the partitions with the new fx module whose
    graph calls the partition submodules.
    """

    # DAG describing the partitions and how they are connected.
    dag: DAG
    # The fx module after splitting it into submodules (partitions).
    module_with_submodules: GraphModule
71
72
73"""Followings are some helper functions for partition manipulation"""
74
75
def reset_partition_device(partitions):
    """Clear the logical device assignment of every partition in ``partitions``.

    Each partition receives its own fresh, empty ``logical_device_ids`` list.
    """
    for p in partitions:
        p.logical_device_ids = []
79
80
def combine_two_partitions(
    partition_0: Partition, partition_1: Partition, partitions: List[Partition]
) -> None:
    """Merge ``partition_0`` and ``partition_1`` into one new partition, in place.

    The merged partition holds the union of both node sets and has its memory
    size recalculated. It is appended to ``partitions``, and the two source
    partitions are removed from that list. Finally every partition's id,
    parents and children are rebuilt via ``reorganize_partitions``.
    """
    merged = Partition(len(partitions))
    merged.nodes = partition_0.nodes.union(partition_1.nodes)
    merged.recalculate_mem_size()
    partitions.append(merged)
    for stale in (partition_0, partition_1):
        partitions.remove(stale)
    reorganize_partitions(partitions)
96
97
def set_parents_and_children(partitions: List[Partition]) -> None:
    """Given a list of partitions, mark parents and children for each partition.

    Partition ``p`` becomes a child of ``partition`` when some node of
    ``partition`` has a user in ``p`` and that node itself is not in ``p``.
    ``partition`` is then also added to ``p.parents``. Existing ``parents`` /
    ``children`` sets are discarded first.
    """
    for partition in partitions:
        partition.children = set()
        partition.parents = set()

    # Index, once, which partitions contain each node. Looking up a user's
    # partitions is then a dict access instead of a scan over every partition.
    # A node may live in more than one partition, so each entry is a list.
    node_to_partitions: Dict[Node, List[Partition]] = {}
    for partition in partitions:
        for node in partition.nodes:
            node_to_partitions.setdefault(node, []).append(partition)

    for partition in partitions:
        for node in partition.nodes:
            for user in node.users:
                for p in node_to_partitions.get(user, ()):
                    # If the node itself is also in p, the edge is internal to p
                    # and p is not a child of the current partition.
                    if p != partition and node not in p.nodes:
                        partition.children.add(p)
                        p.parents.add(partition)
120
121
def reorganize_partitions(partitions: List[Partition]) -> None:
    """Renumber partitions by list position and rebuild their parent/child links.

    Afterwards ``partitions[i].partition_id == i`` for every ``i``, and the
    ``parents`` / ``children`` sets are recomputed by
    ``set_parents_and_children``.
    """
    for new_id, p in enumerate(partitions):
        p.partition_id = new_id
    set_parents_and_children(partitions)
131
132
def get_bfs_level_partition(partitions: List[Partition]) -> None:
    """Given a list of partitions, mark the bfs level for each partition.

    Partitions without parents form level 0. The walk then advances one level
    at a time through ``children``. A partition that is reachable along paths
    of different lengths is revisited, so it ends up with the largest level at
    which it is reached. Partitions that are never reached keep their previous
    ``bfs_level``.

    Raises:
        RuntimeError: if the parent/child links contain a cycle reachable from
            a root partition. Such a walk would otherwise never terminate.
    """
    current_level: Set[Partition] = {p for p in partitions if len(p.parents) == 0}
    # Every partition that has been assigned a level so far.
    visited: Set[Partition] = set()
    level = 0
    while current_level:
        visited.update(current_level)
        # In an acyclic graph, a partition at level L ends a path of L + 1
        # distinct partitions, all of which are already in ``visited``. Reaching
        # level L with fewer distinct partitions means the walk is circling.
        if level >= len(visited):
            raise RuntimeError("Circular dependency detected among partitions")
        next_level: Set[Partition] = set()
        for partition in current_level:
            partition.bfs_level = level
            next_level.update(partition.children)
        current_level = next_level
        level += 1
159
160
def get_node_to_partition_mapping(partitions: List[Partition]) -> Dict[Node, int]:
    """Map every node to the ``partition_id`` of the partition that contains it.

    If a node appears in several partitions, the one that comes last in
    ``partitions`` wins.
    """
    return {
        node: partition.partition_id
        for partition in partitions
        for node in partition.nodes
    }
168
169
def get_logical_id_to_device(devices: List[Device]) -> Dict[int, Device]:
    """Index ``devices`` by their ``logical_id``.

    When several devices share a logical id, the last one in ``devices`` wins.
    """
    return {device.logical_id: device for device in devices}
176
177
def get_device_partition_stats(
    partitions: List[Partition], devices: List[Device]
) -> Tuple[Dict[Device, List[Partition]], Dict[Device, int], List[Partition]]:
    """Summarize how ``partitions`` are currently spread over ``devices``.

    Returns a 3-tuple:
      1. device -> partitions placed on it. A partition assigned to several
         logical devices is listed under each of them.
      2. device -> memory left: ``available_mem_bytes`` minus the
         ``used_mem_bytes`` of every partition placed on that device.
      3. the partitions whose ``logical_device_ids`` is empty.

    Raises:
        KeyError: if a partition refers to a logical id that no device in
            ``devices`` has.
    """
    id_to_device: Dict[int, Device] = {d.logical_id: d for d in devices}
    device_to_partitions: Dict[Device, List[Partition]] = {d: [] for d in devices}
    device_to_left_mem_bytes: Dict[Device, int] = {
        d: d.available_mem_bytes for d in devices
    }

    unplaced: List[Partition] = []
    for partition in partitions:
        if partition.logical_device_ids == []:
            unplaced.append(partition)
            continue
        for logical_id in partition.logical_device_ids:
            device = id_to_device[logical_id]
            device_to_partitions[device].append(partition)
            device_to_left_mem_bytes[device] -= partition.used_mem_bytes

    return device_to_partitions, device_to_left_mem_bytes, unplaced
213
214
def get_device_to_partitions_mapping(
    partitions: List[Partition], devices: List[Device]
):
    """Assign a logical device to every partition that does not have one yet.

    Partitions that already have ``logical_device_ids`` stay where they are, and
    their ``used_mem_bytes`` is charged against those devices. Each remaining
    partition, in list order, is put on the device with the least memory left
    that can still hold it (best fit). That device's logical id is appended to
    the partition's ``logical_device_ids``.

    Args:
        partitions: all partitions; mutated in place.
        devices: the devices available for placement.

    Returns:
        True if every unassigned partition was placed. False as soon as one of
        them fits on no device; later partitions are then left unassigned.
    """

    def calculate_extra_mem_bytes_needed_for(
        partition: Partition, partitions: List[Partition]
    ):
        """Bytes ``partition`` adds to a device that already holds ``partitions``.

        Inputs already present on the device are not counted again. A device
        holding no nodes yet is charged the partition's full ``used_mem_bytes``.
        """
        all_nodes: Set[Node] = set()
        for p in partitions:
            all_nodes.update(p.nodes)
        if not all_nodes:
            return partition.used_mem_bytes
        all_nodes.update(partition.nodes)
        return sum(get_extra_size_of(node, all_nodes) for node in partition.nodes)

    def find_device_for(partition: Partition):
        """Place ``partition`` on the first device in iteration order that has room.

        The caller re-sorts ``device_to_left_mem_bytes`` by remaining memory
        (ascending) before each call, so this picks the tightest device that
        still fits. Returns whether a device was found.
        """
        for d in device_to_left_mem_bytes:
            extra_size_needed = calculate_extra_mem_bytes_needed_for(
                partition, device_to_partitions[d]
            )
            # A partition needing exactly the remaining memory still fits. This
            # matches the inclusive capacity checks used by the partitioner.
            if extra_size_needed <= device_to_left_mem_bytes[d]:
                device_to_partitions[d].append(partition)
                partition.logical_device_ids.append(d.logical_id)
                device_to_left_mem_bytes[d] -= extra_size_needed
                return True
        return False

    (
        device_to_partitions,
        device_to_left_mem_bytes,
        no_device_partitions,
    ) = get_device_partition_stats(partitions, devices)

    # Find devices for all the partitions without a device.
    for partition in no_device_partitions:
        # Re-sort before every placement: the previous placement changed the
        # free memory of one device.
        device_to_left_mem_bytes = dict(
            sorted(device_to_left_mem_bytes.items(), key=operator.itemgetter(1))
        )
        if not find_device_for(partition):
            return False
    return True
268
269
def check_dependency(partition):
    """Return True if ``partition`` can reach itself by following ``children``.

    Walks the children links breadth-first from ``partition``, visiting each
    partition at most once, and reports a circular dependency as soon as
    ``partition`` shows up again.
    """
    seen: Set[Partition] = {partition}
    frontier: Deque[Partition] = deque([partition])
    while frontier:
        current = frontier.popleft()
        for child in current.children:
            if child == partition:
                return True
            if child in seen:
                continue
            seen.add(child)
            frontier.append(child)
    return False
286
287
288class Partitioner:
289    """A fx module may not fit into one device.
290    Partitioner class helps partition one fx module into submodules (partitions),
291    so that the submodules can be executed crossing different accelerators.
292    The main function of this class is self.partition_graph.
293    It partitions the fx module based on the scheme specified in partition_config
294    A DAG structure is returned
295    along with a new fx module with submodule nodes.
296    """
297
298    def __init__(self) -> None:
299        self.partitions: List[Partition] = []
300        self.node_to_partition: Dict[Node, int] = {}
301        self.devices: List[Device] = []
302
303    def partition_graph(
304        self,
305        fx_module: GraphModule,
306        torch_module: torch.nn.Module,
307        partitioner_config: PartitionerConfig,
308    ) -> PartitionResult:
309        """Given the fx module, torch module and partitioner_config,
310        find the partitions, do the partitions,
311        and then return a DAG and a new fx module with submodule nodes (partitions)
312        """
313        self.graph_module = fx_module
314        self.torch_module = torch_module
315        self.devices = partitioner_config.devices
316        if len(self.devices) == 0:
317            raise RuntimeError("No devices")
318        # Tag the size in bytes to all nodes in the graph_module.
319        get_size_of_all_nodes(self.graph_module)
320        # Check if there are op nodes in the fx module
321        nodes = self.graph_module.graph.nodes
322        if all(node.op in {"placeholder", "get_attr", "output"} for node in nodes):
323            raise RuntimeError("No Partition since no operations in the module")
324        # Calculate total size of the fx module
325        total_size_of_graph = 0
326        for node in nodes:
327            if node.op == "output":
328                break
329            total_size_of_graph += node.size_bytes.total_size
330        # Find the device with the max mem size
331        device_with_max_mem = max(self.devices, key=lambda d: d.available_mem_bytes)
332        # AOT based partition
333        if partitioner_config.mode == PartitionMode.aot_based:
334            self.aot_based_partition(
335                partitioner_config.node_to_partition_mapping,
336                partitioner_config.partition_to_logical_device_mapping,
337            )
338        # Single partition if the whole module can be fit into one device
339        elif total_size_of_graph <= device_with_max_mem.available_mem_bytes:
340            self.find_single_partition(
341                total_size_of_graph, logical_device_id=device_with_max_mem.logical_id
342            )
343        elif total_size_of_graph > sum(d.available_mem_bytes for d in self.devices):
344            raise RuntimeError("Devices have no enough memory for the module")
345        else:
346            # Sparse nn based partition
347            if partitioner_config.mode == PartitionMode.sparse_nn:
348                available_mem_bytes = self.devices[0].available_mem_bytes
349                if not all(
350                    device.available_mem_bytes == available_mem_bytes
351                    for device in self.devices
352                ):
353                    raise RuntimeError("All devices must have same memory size!")
354                # sparse_nn_partition only support same memory size
355                # TODO: add different size support for sparse_nn_partition
356                self.sparse_nn_partition(available_mem_bytes)
357            # Cost aware partition
358            elif partitioner_config.mode == PartitionMode.cost_aware:
359                self.cost_aware_partition(
360                    partitioner_config.transfer_rate_bytes_per_sec,
361                    partitioner_config.node_to_latency_mapping,
362                )
363            # KL based partition
364            elif partitioner_config.mode == PartitionMode.kl_based:
365                self.kl_based_partition(
366                    partitioner_config.transfer_rate_bytes_per_sec,
367                    partitioner_config.node_to_latency_mapping,
368                )
369            else:
370                self.size_based_partition()
371
372        # Saturate host if possible.
373        if partitioner_config.saturate_host:
374            self.saturate_host()
375
376        # Partition the graph module based on the partition assignment.
377        module_with_submodules = self.do_partition()
378
379        # The DAG contains DAGNodes with info of each partition's input nodes, output nodes
380        # and how partitions are connected.
381        dag = self.dump_dag(module_with_submodules)
382        ret = PartitionResult(dag, module_with_submodules)
383        return ret
384
385    def find_single_partition(
386        self, total_size_of_graph, logical_device_id: int = 0
387    ) -> None:
388        """Fit the whole fx module into one device"""
389        partition_0 = self.create_partition()
390        for node in self.graph_module.graph.nodes:
391            if node.op == "output":
392                # Skip the output node, but there can
393                # be nodes after the output in certain cases.
394                continue
395            partition_0.nodes.add(node)
396        partition_0.used_mem_bytes = total_size_of_graph
397        partition_0.logical_device_ids = [logical_device_id]
398        # Get the node to partition mapping
399        self.node_to_partition = get_node_to_partition_mapping(self.partitions)
400        return
401
402    def size_based_partition(self) -> None:
403        """This method is to partition the fx module based on memory size.
404        It uses greedy approach. The result may not be the best.
405        The basic idea is:
406        Step 1:
407        Find a device which has enough memory to fit the current node, create a empty partition
408        with the size of that device.
409        Then keep adding the following nodes into the partition until the partition is full.
410        Step 2:
411        Repeat Step 1 until no device left
412        Step 3:
413        If some nodes are left, create a partition for each left node (single node partition).
414        and then try to map those partitions into logical devices with enough mem left.
415        """
416
417        def find_device_based_on_size(node) -> Device:
418            """Given a node, this function is to find a logical device
419            that could fit the node.
420            """
421            mem_size_needed = get_extra_size_of(node, set())
422            device = Device("", -1, -1)
423            for d in self.devices:
424                if (
425                    d not in occupied_devices
426                    and d.available_mem_bytes >= mem_size_needed
427                ):
428                    device = d
429                    break
430            if device.available_mem_bytes < 0:
431                raise RuntimeError(str(node) + "is too large to fit any device")
432            occupied_devices.append(device)
433            return device
434
435        # Track partition and its left mem size
436        partition_to_left_mem_bytes: Dict[Partition, int] = {}
437        # Track all the devices that have been used
438        occupied_devices: List[Device] = []
439        partition = self.create_partition()
440        for node in self.graph_module.graph.nodes:
441            if node.op in {"call_module", "call_method", "call_function"}:
442                # Check if there are devices left
443                if len(self.partitions) <= len(self.devices):
444                    total_size_of_input_nodes = get_extra_size_of(node, partition.nodes)
445                    # Check if the current partition is the very first partition
446                    if partition.used_mem_bytes == 0:
447                        # Find a device to fit the first node, return available mem size
448                        device = find_device_based_on_size(node)
449                        occupied_devices.append(device)
450                        # Update partition and its left mem size
451                        partition_to_left_mem_bytes[
452                            partition
453                        ] = device.available_mem_bytes
454                        # Update available mem for the current partition
455                        partition.logical_device_ids.append(device.logical_id)
456                    else:
457                        # The current partition is not the first partition
458                        # Check if the current node can fit into current partition
459                        if (
460                            partition_to_left_mem_bytes[partition]
461                            < total_size_of_input_nodes
462                        ):
463                            # Check if no device is left
464                            if len(self.partitions) == len(self.devices):
465                                # No device is left
466                                # Put the previous partitions into a list (non_single_node_partitions)
467                                non_single_node_partitions = self.partitions[:]
468                                # Create the first single node partition for the current node
469                                self.create_single_node_partition(node)
470                                continue
471                            # Some devices are still left
472                            # Create a new partition with a mem size that is enough for the current node
473                            device = find_device_based_on_size(node)
474                            partition = self.create_partition()
475                            total_size_of_input_nodes = get_extra_size_of(
476                                node, partition.nodes
477                            )
478                            partition_to_left_mem_bytes[
479                                partition
480                            ] = device.available_mem_bytes
481                            partition.logical_device_ids.append(device.logical_id)
482                    partition.add_node(node)
483                    partition_to_left_mem_bytes[partition] -= total_size_of_input_nodes
484                # Create single node partitions if no device is left
485                else:
486                    self.create_single_node_partition(node)
487        reorganize_partitions(self.partitions)
488        # Get the node to partition mapping
489        self.node_to_partition = get_node_to_partition_mapping(self.partitions)
490        # Mapping all partitions into device
491        found_partition_to_device_mapping = get_device_to_partitions_mapping(
492            self.partitions, self.devices
493        )
494        if not found_partition_to_device_mapping:
495            raise RuntimeError("Cannot Get a Valid Partition to Logical Device Mapping")
496        return
497
498    def saturate_host(self) -> None:
499        """Saturate host by assigning replicates to unused devices with enough memory.
500        It uses a greedy approach to find a next available set of devices to place all split
501        partitions: For each used device, it searches for an idle device with minimal memory
502        size that can hold all the partition located on that device; If the search is successful
503        for all used devices, it then assigns the new devices' logical ID to the corresponding
504        partition.
505        """
506        (
507            device_to_partitions,
508            device_to_left_mem_bytes,
509            no_device_partitions,
510        ) = get_device_partition_stats(self.partitions, self.devices)
511
512        assert (
513            len(no_device_partitions) == 0
514        ), f"Expect no_device_partitions has 0 device, but get {len(no_device_partitions)}"
515
516        # Devices that hold partitions
517        used_devices = [d for d in self.devices if len(device_to_partitions[d]) > 0]
518        # Track replicates of the assigned devices
519        replicated_device_to_used_device: Dict[Device, Device] = {}
520
521        while len(used_devices) * 2 + len(replicated_device_to_used_device) <= len(
522            self.devices
523        ):
524            # Success flag for this round
525            success = True
526            # Devices that have not been assigned
527            idle_devices = [
528                d
529                for d in self.devices
530                if d not in used_devices and d not in replicated_device_to_used_device
531            ]
532            # Temporary mapping from replicated device to original device
533            temp_replicate_mapping = {}
534
535            # Find a new device to replicate all partitions on an used device
536            for used_device in used_devices:
537                # Idle devices that have enough memory
538                available_devices = [
539                    d
540                    for d in idle_devices
541                    if d.available_mem_bytes
542                    >= used_device.available_mem_bytes
543                    - device_to_left_mem_bytes[used_device]
544                ]
545                if len(available_devices) == 0:
546                    success = False
547                    break
548                new_device = min(available_devices, key=lambda d: d.available_mem_bytes)
549                idle_devices.remove(new_device)
550                temp_replicate_mapping[new_device] = used_device
551
552            if not success:
553                break
554            replicated_device_to_used_device.update(temp_replicate_mapping)
555
556        # Update logical device IDs assigned to the partitions
557        for (
558            replicate_device,
559            original_device,
560        ) in replicated_device_to_used_device.items():
561            logical_id = replicate_device.logical_id
562            for partition in device_to_partitions[original_device]:
563                partition.logical_device_ids.append(logical_id)
564        for p in self.partitions:
565            print(p.logical_device_ids)
566
567    def do_partition(self) -> GraphModule:
568        """Return a new fx module with submodule nodes (partitions)."""
569        module_with_submodules = split_module(
570            self.graph_module,
571            self.torch_module,
572            lambda node: self.node_to_partition[node],
573        )
574        return module_with_submodules
575
576    def dump_dag(self, module_with_submodules: GraphModule) -> DAG:
577        """Return the dag structure and the new fx module with submodules."""
578        dag = DAG()
579        for node in module_with_submodules.graph.nodes:
580            if node.op == "output":
581                break
582            if node.op in {"placeholder", "get_attr"}:
583                continue
584            if node.target == operator.__getitem__:
585                continue
586            input_nodes: Dict[Node, None] = {}
587            map_arg(node.args, input_nodes.setdefault)
588            map_arg(node.kwargs, input_nodes.setdefault)
589            # When a node has two or more output nodes,
590            # it outputs its result to 'getitem' nodes.
591            # Those 'getitem' nodes are the output node for this node.
592            # Otherwise, the output node is this node itself.
593            if len(node.users) > 1:
594                output_nodes = list(node.users)
595            else:
596                output_nodes = [node]
597            partition_id = int(node.name.rsplit("_", 1)[-1])
598            device_ids = self.partitions[partition_id].logical_device_ids
599            size_bytes = self.partitions[partition_id].used_mem_bytes
600            dag.create_node(
601                node, list(input_nodes), output_nodes, device_ids, size_bytes
602            )
603        return dag
604
605    def create_partition(self) -> Partition:
606        """Create a partition and append it to self.partitions."""
607        partition_id = len(self.partitions)
608        partition = Partition(partition_id)
609        self.partitions.append(partition)
610        return partition
611
612    def create_single_node_partition(self, node):
613        """Create a partition for a single node"""
614        partition = self.create_partition()
615        partition.add_node(node)
616        return
617
618    def sparse_nn_partition(self, available_mem_bytes: int) -> None:
619        """This method partition a sparse nn module.
620        It is size based partition but different from size_based_partition,
621        it only works when all the devices have same memory size (available_mem_bytes).
622        In the future, devices with different mem sizes will be supported like size_based_partition.
623        It first traverse all the nodes and do the partitions based on the same memory size.
624        If the current partition has no enough memory left for a new op node
625        (call_module, call_method, call_function), a new partition is created.
626        When crossing the boundary between non-embedding nodes and embedding nodes,
627        a new partition is created regardlessly.
628        For example, if the current node is a non-embedding node but the next node is an
629        embedding node, a new partition is created for the next node.
630        After the partition, the partitions are combined as much as possible.
631        The rule is that a non-embedding partition only
632        combines with another non-embedding one.
633        So as the embedding partitions.
634        """
635
636        def combine_partitions_based_on_size(
637            partitions: List[Partition], available_mem_bytes: int
638        ) -> None:
639            """Combining small partitions together to keep as less partitions as possible.
640            Here is an example of the algorithm to do this:
641            Assume some partitions, we first sort them based on partition used memory size.
642            [(partition_4, 1), (partition_3, 1), (partition_2, 2), (partition_1, 7), (partition_0, 9)]
643            The available memory is 10.
644            step 1: self.find_partition_to_combine_based_on_size()
645            First, mark bfs level for each partition
646            Second, look the smallest partition, partition_4: 10 - 1 = 9
647            It means any partition has a used memory equal or less than 9 could combine this partition
648            We go from the largest and selection partition_0.
649            Check the bfs level for two partitions, if the level difference is less than 2,
650            it can be combined.
651            step 2: repeat step 1 until no partitions can be combined
652            """
653            find_combination = True
654            while find_combination:
655                # Sort partitions based on memory size
656                sorted_partitions = sorted(partitions, key=lambda p: p.used_mem_bytes)
657                # Mark bfs level
658                get_bfs_level_partition(self.partitions)
659                find_combination, partitions = find_partition_to_combine_based_on_size(
660                    sorted_partitions, available_mem_bytes, partitions
661                )
662            return
663
664        def calculate_mem_bytes_needed(p1, p2):
665            """Given two partitions, calculate how many mem bytes
666            are needed if two partitions are combined
667            """
668            nodes = p1.nodes.union(p2.nodes)
669            mem_bytes_needed = 0
670            for node in nodes:
671                mem_bytes_needed += get_extra_size_of(node, nodes)
672            return mem_bytes_needed
673
674        def find_partition_to_combine_based_on_size(
675            sorted_partitions: List[Partition],
676            available_mem_bytes: int,
677            partitions: List[Partition],
678        ) -> Tuple[bool, List[Partition]]:
679            """step 1 in combine_partition_based_on_size()"""
680            find_combination = False
681            smallest_partition = sorted_partitions.pop(0)
682            for p in sorted_partitions[::-1]:
683                if abs(smallest_partition.bfs_level - p.bfs_level) <= 1:
684                    # Calculate how many bytes needed if combined
685                    mem_bytes_needed = calculate_mem_bytes_needed(p, smallest_partition)
686                    if mem_bytes_needed <= available_mem_bytes:
687                        combine_two_partitions(p, smallest_partition, self.partitions)
688                        partitions.remove(smallest_partition)
689                        partitions.remove(p)
690                        partitions.append(self.partitions[-1])
691                        find_combination = True
692                        break
693            return find_combination, partitions
694
695        def reset_partition_in_sparse_nn(partition, new_partition=True):
696            """If crossing the boundary between non-embedding nodes and
697            embedding nodes, create a new partition
698            """
699            if in_embedding_region:
700                embedding_partitions.append(partition)
701            else:
702                non_embedding_partitions.append(partition)
703            if new_partition:
704                partition = self.create_partition()
705                partition.left_mem_bytes = available_mem_bytes
706                return partition
707            return None
708
709        def is_embedding_node(node: Node) -> bool:
710            """Check if a node is an embedding node"""
711            if node.op == "call_module":
712                submodule = self.graph_module
713                for atom in str(node.target).split("."):
714                    if not hasattr(submodule, atom):
715                        raise RuntimeError(
716                            f"Module {submodule} has no attribute {atom}"
717                        )
718                    submodule = getattr(submodule, atom)
719                    if "Embedding" in str(submodule):
720                        return True
721            return False
722
723        # Track embedding partitions and non-embedding partitions separately
724        embedding_partitions: List[Partition] = []
725        non_embedding_partitions: List[Partition] = []
726        # A Flag to check the boundary
727        in_embedding_region: bool = False
728        partition = self.create_partition()
729        for node in self.graph_module.graph.nodes:
730            if node.op in {"call_module", "call_method", "call_function"}:
731                # Check if crossing the boundary between embedding nodes and non embedding nodes
732                if is_embedding_node(node) != in_embedding_region:
733                    # Crossing the boundary
734                    # Check if the current partition is an empty partition
735                    if partition.used_mem_bytes != 0:
736                        # The current partition isn't an empty partition. Create a new one.
737                        partition = reset_partition_in_sparse_nn(partition)
738                    in_embedding_region = not in_embedding_region
739                total_size_of_input_nodes = get_extra_size_of(node, partition.nodes)
740                if (
741                    total_size_of_input_nodes + partition.used_mem_bytes
742                    > available_mem_bytes
743                ):
744                    partition = reset_partition_in_sparse_nn(partition)
745                    total_size_of_input_nodes = get_extra_size_of(node, partition.nodes)
746                    if total_size_of_input_nodes > available_mem_bytes:
747                        raise RuntimeError(
748                            node.target + "is too large to fit into a device"
749                        )
750                partition.add_node(node)
751        reset_partition_in_sparse_nn(partition, new_partition=False)
752        # Set parents and children for partitions
753        set_parents_and_children(self.partitions)
754        # Combining non-embedding partitions
755        combine_partitions_based_on_size(non_embedding_partitions, available_mem_bytes)
756        # Combining embedding partitions
757        combine_partitions_based_on_size(embedding_partitions, available_mem_bytes)
758        total_size_of_non_embedding_partitions = 0
759        for partition in non_embedding_partitions:
760            total_size_of_non_embedding_partitions += partition.used_mem_bytes
761        # Check if devices are enough for all partitions
762        if len(embedding_partitions) > len(self.devices):
763            msg = (
764                "Need "
765                + str(len(embedding_partitions))
766                + " devices, but only "
767                + str(len(self.devices))
768                + " provided"
769            )
770            raise RuntimeError(msg)
771        occupied_devices = []
772        for i, partition in enumerate(embedding_partitions):
773            # Check if all non-embedding partitions can fit into embedding partition devices
774            if (
775                total_size_of_non_embedding_partitions + partition.used_mem_bytes
776                > available_mem_bytes
777            ):
778                raise RuntimeError(
779                    "partition_"
780                    + str(partition.partition_id)
781                    + "(embedding partition) and non embedding partitions can not fit into one device"
782                )
783            else:
784                # Add logical device to the partition
785                partition.logical_device_ids = [self.devices[i].logical_id]
786                occupied_devices.append(self.devices[i].logical_id)
787        # Add logical devices to the non_embedding_partitions
788        for partition in non_embedding_partitions:
789            partition.logical_device_ids = occupied_devices
790        # Get the node to partition mapping
791        self.node_to_partition = get_node_to_partition_mapping(self.partitions)
792        return
793
794    def cost_aware_partition(
795        self,
796        transfer_rate_bytes_per_sec: float,
797        node_to_latency_mapping: Dict[Node, NodeLatency],
798    ) -> None:
799        """This method is to partition the fx module based on the cost.
800        The cost is the total latency of running the whole fx module.
801        In partitioner_utils.py, the cost model is built.
802        The cost aware partition algorithm is:
803        #1. At every beginning, each node is a partition.
804            Then we map all the partitions to the devices
805            and calculate the cost
806        #2. Then try to pre-combine any two of the partitions if the two
807            partitions can be combined.
808            (the bfs level is less than 2 or two partitions are connected and
809            can find partition to device mapping)
810            See if any partition pair could reduce the current cost.
811            Choose the pair that shows the minimum cost and then combine them
812        #3. Repeat #2 until the cost cannot be reduced.
813        """
814
815        def try_combining_partitions(p0_index, p1_index, partitions) -> float:
816            """Given two partitions and a list of partitions, combine these two partitions
817            and see what is the cost of the modified partition list
818            """
819            p0 = partitions[p0_index]
820            p1 = partitions[p1_index]
821            """If two partitions' bfs level are less than 2 or two partitions are connected to each other,
822               then they can be combined
823            """
824            if (
825                (abs(p0.bfs_level - p1.bfs_level) <= 1)
826                or (p0 in p1.parents)
827                or p0 in (p1.children)
828            ):
829                combine_two_partitions(p0, p1, partitions)
830                # Check if a circular dependency exists after combining
831                if check_dependency(partitions[-1]):
832                    return float("inf")
833                # Check if the modified partition list can be mapped to devices after combination
834                reset_partition_device(partitions)
835                found_deivce = get_device_to_partitions_mapping(
836                    partitions, self.devices
837                )
838                if not found_deivce:
839                    return float("inf")
840                # Calculate the new cost
841                partition_to_latency_mapping = get_partition_to_latency_mapping(
842                    partitions, node_to_latency_mapping
843                )
844                cost = get_latency_of_partitioned_graph(
845                    partitions,
846                    partition_to_latency_mapping,
847                    transfer_rate_bytes_per_sec,
848                )
849                return cost
850            # If two partition can not be combined, the cost is inf
851            return float("inf")
852
853        def search_combination(
854            transfer_rate_bytes_per_sec, node_to_latency_mapping
855        ) -> bool:
856            """Given transfer rate between partitions and each node's latency,
857            find two partitions to combine so the cost of the partitions can
858            be reduced.
859            The algorithm is :
860            1. Go through all the partition pairs and see
861            if any pair of partitions can be combined.
862            2. Calculate the cost after the combination.
863            3. Select the minimum cost and combine its corresponding partition pair.
864            """
865            partition_to_latency_mapping = get_partition_to_latency_mapping(
866                self.partitions, node_to_latency_mapping
867            )
868            cost = get_latency_of_partitioned_graph(
869                self.partitions,
870                partition_to_latency_mapping,
871                transfer_rate_bytes_per_sec,
872            )
873            if len(self.partitions) == 1:
874                return False
875            partition_pair: List[int] = []
876            for i in range(len(self.partitions) - 1):
877                for j in range(i + 1, len(self.partitions)):
878                    # Try to combine the partition pair
879                    # and see the new cost after combination
880                    new_cost = try_combining_partitions(i, j, self.partitions[:])
881                    if new_cost <= cost:
882                        partition_pair = [i, j]
883                        cost = new_cost
884                    reorganize_partitions(self.partitions)
885            # If a partition pair is found, combine them
886            if len(partition_pair) != 0:
887                p0 = self.partitions[partition_pair[0]]
888                p1 = self.partitions[partition_pair[1]]
889                combine_two_partitions(p0, p1, self.partitions)
890            get_bfs_level_partition(self.partitions)
891            reset_partition_device(self.partitions)
892            get_device_to_partitions_mapping(self.partitions, self.devices)
893            return len(partition_pair) != 0
894
895        for node in self.graph_module.graph.nodes:
896            if node.op not in {"placeholder", "get_attr", "output"}:
897                self.create_single_node_partition(node)
898        # Set up parent partitions and children partitions for each partition
899        set_parents_and_children(self.partitions)
900        # Get bfs level for each partition
901        get_bfs_level_partition(self.partitions)
902        find_combination = True
903        while find_combination:
904            # Search for a pair partition to generate the minimum new cost,
905            # then combine them
906            find_combination = search_combination(
907                transfer_rate_bytes_per_sec, node_to_latency_mapping
908            )
909        # Make sure all partitions are set up correctly
910        reorganize_partitions(self.partitions)
911        # Set up node to partition mapping
912        self.node_to_partition = get_node_to_partition_mapping(self.partitions)
913        return
914
915    def kl_based_partition(
916        self,
917        transfer_rate_bytes_per_sec: float,
918        node_to_latency_mapping: Dict[Node, NodeLatency],
919    ) -> None:
920        """This function is a cost aware partition based
921        on Kernighan-Lin algorithm.
922        First, the graph is partitioned using size_based_partition.
923        Then, each node is swapped with any other node in a different
924        partition, and at the same time, the cost is estimated after
925        the swapping.
926        For example, we have nodes n0, n1, n2, n3 and n4.
927        Using size_based_partition, n0 and n1 are in Partition p0.
928        n2, n3 and n4 in Partition p1. The current cost is estimated.
929        We first tried using n0 to swap with n2 from the other partition.
930        Then we see that swapping n0 and n2 shows a lower cost
931        than the current cost and it is the minimum among other pairs like
932        (n0, None)(This means moving n0 to Partition without swapping other nodes),
933        (n0, n3) and (n0, n4). We swap n0 and n2 and set the new cost
934        as the current cost.
935        Then We repeat this process for all the other nodes until all swapping pairs
936        are tried.
937        """
938
939        def swap_nodes(n0, n1, p0, p1):
940            # Either n0 or n1 could be None
941            # That means we simply move the node
942            # to another partition
943            if n0 is not None:
944                p0.remove_node(n0)
945                p1.add_node(n0)
946            if n1 is not None:
947                p0.add_node(n1)
948                p1.remove_node(n1)
949
950        def try_swap_nodes(
951            n0, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec
952        ):
953            cost = float("inf")
954            swap_nodes(n0, n1, p0, p1)
955            # Reorganize partitions after swapping
956            reorganize_partitions(self.partitions)
957            # Check if there is a circular dependency after swapping
958            if (not check_dependency(p0)) and (not check_dependency(p1)):
959                reset_partition_device(self.partitions)
960                partition_to_latency_mapping = get_partition_to_latency_mapping(
961                    self.partitions, node_to_latency_mapping
962                )
963                # Check if all partitions can be mapped to logical devices after swapping
964                found_device = get_device_to_partitions_mapping(
965                    self.partitions, self.devices
966                )
967                if not found_device:
968                    cost = float("inf")
969                else:
970                    cost = get_latency_of_partitioned_graph(
971                        self.partitions,
972                        partition_to_latency_mapping,
973                        transfer_rate_bytes_per_sec,
974                    )
975            # Swap back and reset all partitions back to original
976            swap_nodes(n1, n0, p0, p1)
977            reorganize_partitions(self.partitions)
978            reset_partition_device(self.partitions)
979            get_device_to_partitions_mapping(self.partitions, self.devices)
980            return cost
981
982        def swap_node_to_partition(
983            node, p0, p1, node_to_latency_mapping, transfer_rate_per_sec
984        ):
985            """This function helps to swap one node from partition p0
986            with all the nodes in another partition p1
987            """
988            p1_nodes = list(p1.nodes) + [None]
989            min_cost = float("inf")
990            node_pair: List[Node] = []
991            for n1 in p1_nodes:
992                # Ignore the node if it is not a op node
993                if n1 is not None and n1.op in {"placeholder", "get_attr"}:
994                    continue
995                # Try swapping node in p0 with n1 in p1
996                cost = try_swap_nodes(
997                    node, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec
998                )
999                if cost < min_cost:
1000                    node_pair = [node, n1]
1001                    min_cost = cost
1002            return cost, node_pair  # type: ignore[possibly-undefined]
1003
1004        # First use size_base_partition
1005        self.size_based_partition()
1006        partition_to_latency_mapping = get_partition_to_latency_mapping(
1007            self.partitions, node_to_latency_mapping
1008        )
1009        # Calculate the cost of the partitions
1010        cost = get_latency_of_partitioned_graph(
1011            self.partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec
1012        )
1013        # Keep tracking the node pair that shows the better cost
1014        node_pair: List[Node] = []
1015        # Keep tracking the partition pair of node pair
1016        partition_pair: List[Partition] = []
1017        # Collect all the op nodes from the graph
1018        op_nodes = []
1019        for n in self.graph_module.graph.nodes:
1020            if n.op not in {"placeholder", "get_attr", "output"}:
1021                op_nodes.append(n)
1022        for node in op_nodes:
1023            # Find which partition the current node belongs
1024            p0_index = self.node_to_partition[node]
1025            p0 = self.partitions[p0_index]
1026            # Go through all the other partitions to swap
1027            # with other nodes from those partitions
1028            for p1_index, _ in enumerate(self.partitions):
1029                if p0_index != p1_index:
1030                    p1 = self.partitions[p1_index]
1031                    new_cost, new_node_pair = swap_node_to_partition(
1032                        node,
1033                        p0,
1034                        p1,
1035                        node_to_latency_mapping,
1036                        transfer_rate_bytes_per_sec,
1037                    )
1038                    # Update the cost
1039                    # Track the swapped node pair and their partitions
1040                    if new_cost < cost:
1041                        cost = new_cost
1042                        node_pair = new_node_pair
1043                        partition_pair = [p0, p1]
1044            # Do the swapping after trying all the nodes from a partition
1045            if len(node_pair) != 0:
1046                swap_nodes(
1047                    node_pair[0], node_pair[1], partition_pair[0], partition_pair[1]
1048                )
1049                reorganize_partitions(self.partitions)
1050                get_device_to_partitions_mapping(self.partitions, self.devices)
1051        reorganize_partitions(self.partitions)
1052        # Mapping the device to the partition
1053        get_device_to_partitions_mapping(self.partitions, self.devices)
1054        return
1055
1056    def aot_based_partition(
1057        self, node_to_partition_mapping, partition_to_logical_device_mapping
1058    ):
1059        """This function helps to rebuild the partitions given the nodes and its
1060        corresponding partition id
1061        """
1062        partition_id_to_partition_mapping: Dict[int, Partition] = {}
1063        self.node_to_partition = node_to_partition_mapping
1064        for node in self.node_to_partition:
1065            partition_id = self.node_to_partition[node]
1066            # If the requested partition has not been created, create the partition
1067            if partition_id not in partition_id_to_partition_mapping:
1068                partition = Partition(partition_id)
1069                self.partitions.append(partition)
1070                partition_id_to_partition_mapping[partition_id] = partition
1071                partition.logical_device_ids = partition_to_logical_device_mapping[
1072                    partition_id
1073                ]
1074            else:
1075                partition = partition_id_to_partition_mapping[
1076                    self.node_to_partition[node]
1077                ]
1078            # Add the current node into the partition
1079            partition.add_node(node)
1080