# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-ignore-all-errors
import re
from typing import List

import torch

from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.export.exported_program import ExportedProgram
from torch.library import impl, Library


fallback_op_lib = Library("llama", "DEF")
# Register the fallback operator.
fallback_op_lib.define("fallback(Tensor input) -> Tensor")


@impl(fallback_op_lib, "fallback")
def fallback_impl(a: torch.Tensor) -> torch.Tensor:
    return a


# Register the out variant.
fallback_op_lib.define("fallback.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)")


@impl(fallback_op_lib, "fallback.out")
def fallback_out_impl(a: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor:
    out.copy_(a)
    return out
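
# A minimal usage sketch (illustrative only): once this module is imported,
# the registrations above expose the op as torch.ops.llama.fallback, which
# simply passes its input through unchanged.
#
#   x = torch.randn(2, 8)
#   y = torch.ops.llama.fallback(x)
#   assert torch.equal(x, y)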


class SplitGraph(ExportPass):
    """
    Split the model into multiple partitions.
    Device memory is limited, so the whole Llama model cannot be
    loaded as a single .pte file.
    """

    def __init__(self, shard_layers: List[int]):
        super().__init__()
        self.shard_layers = shard_layers

    def _insert_fallback_op(
        self, graph_module: torch.fx.GraphModule
    ) -> None:
        """
        Insert a fallback op before each layer at which the model is sharded.
        Example:
            For a 12-layer Llama model with num_sharding = 3:
            The first partition contains layers [0, 4) and the embedding.
            The second partition contains layers [4, 8).
            The third partition contains layers [8, 12) and the output.
        """
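        # For the example above, split_graph() below passes
        # shard_layers == [0, 4, 8]; fallback ops are then inserted after the
        # last nodes of layers 3 and 7 (layer 0 has no preceding layer).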
        # Matches the layer index in a module path such as "layers.11".
        pattern = r"layers\.(\d+)"
        prev_node = None
        prev_layer = None
        for node in graph_module.graph.nodes:
            if node.op != "call_function" or "nn_module_stack" not in node.meta:
                continue

            module_values_list = list(node.meta["nn_module_stack"].values())
            full_qualified_name = module_values_list[-1][0]
            # Search which layer this node belongs to
            match = re.search(pattern, full_qualified_name)
            if match is None:
                continue

            cur_layer = int(match.group(1))
            # If this node starts a shard-boundary layer, prev_node is the
            # last node of the previous layer; insert the fallback op there.
            if cur_layer in self.shard_layers and prev_layer == cur_layer - 1:
                with graph_module.graph.inserting_after(prev_node):
                    users = list(prev_node.users.keys())
                    inserted_node = graph_module.graph.create_node(
                        "call_function",
                        exir_ops.edge.llama.fallback.default,
                        (prev_node,),
                    )
                    # Propagate value and quantization metadata so later
                    # passes see the fallback node as a drop-in replacement.
                    inserted_node.meta["val"] = prev_node.meta["val"]
                    if prev_node.meta.get(QCOM_QUANT_ATTRS, None):
                        inserted_node.meta[QCOM_QUANT_ATTRS] = prev_node.meta[
                            QCOM_QUANT_ATTRS
                        ]
                    for user in users:
                        user.replace_input_with(prev_node, inserted_node)

            prev_layer = cur_layer
            prev_node = node

    def call(self, graph_module: torch.fx.GraphModule):
        self._insert_fallback_op(graph_module)
        graph_module.recompile()
        return PassResult(graph_module, True)


def split_graph(edge_program: ExportedProgram, num_layers: int, shares: int):
    graph_module = edge_program.graph_module
    shard_layers = list(range(0, num_layers, num_layers // shares))
    return SplitGraph(shard_layers)(graph_module)
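

# A minimal usage sketch (illustrative; `edge_program` is a hypothetical name
# for the result of an earlier torch.export / to_edge step):
#
#   # 12 decoder layers split into 3 shards -> shard_layers == [0, 4, 8].
#   result = split_graph(edge_program, num_layers=12, shares=3)
#   sharded_graph_module = result.graph_module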