# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# pyre-ignore-all-errors
import re
from typing import List

import torch

from executorch.backends.qualcomm.utils.constants import QCOM_QUANT_ATTRS
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.export.exported_program import ExportedProgram
from torch.library import impl, Library


fallback_op_lib = Library("llama", "DEF")
# Register the fallback operator used to mark shard boundaries.
fallback_op_lib.define("fallback(Tensor input) -> Tensor")


@impl(fallback_op_lib, "fallback")
def fallback_impl(a: torch.Tensor) -> torch.Tensor:
    return a


# Register the out variant of the fallback operator.
fallback_op_lib.define("fallback.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)")


@impl(fallback_op_lib, "fallback.out")
def fallback_out_impl(a: torch.Tensor, *, out: torch.Tensor) -> torch.Tensor:
    out.copy_(a)
    return out


class SplitGraph(ExportPass):
    """
    Pass to split the model into multiple partitions. Because on-device
    memory is limited, the whole llama model cannot be loaded from a
    single pte file.
    """

    def __init__(self, shard_layers: List[int]):
        super().__init__()
        self.shard_layers = shard_layers

    def _insert_fallback_op(self, graph_module: torch.fx.GraphModule) -> None:
        """
        Insert a fallback op before each layer at which the graph is sharded.
        Example:
            For a 12-layer llama model with num_sharding = 3:
            The first partition contains the embedding and layers [0, 4).
            The second partition contains layers [4, 8).
            The third partition contains layers [8, 12) and the output.
        """
        pattern = r"layers.(\d+)"
        prev_node = None
        prev_layer = None
        for node in graph_module.graph.nodes:
            if node.op != "call_function" or "nn_module_stack" not in node.meta:
                continue

            module_values_list = list(node.meta["nn_module_stack"].values())
            full_qualified_name = module_values_list[-1][0]
            # Find which layer this node belongs to.
            match = re.search(pattern, full_qualified_name)
            if match is None:
                continue

            cur_layer = int(match.group(1))
            # If cur_layer starts a new shard, prev_node is the last node of the
            # previous layer; insert the fallback op right after it to mark the
            # shard boundary.
            if cur_layer in self.shard_layers and prev_layer == cur_layer - 1:
                with graph_module.graph.inserting_after(prev_node):
                    users = list(prev_node.users.keys())
                    inserted_node = graph_module.graph.create_node(
                        "call_function",
                        exir_ops.edge.llama.fallback.default,
                        (prev_node,),
                    )
                    inserted_node.meta["val"] = prev_node.meta["val"]
                    if prev_node.meta.get(QCOM_QUANT_ATTRS, None):
                        inserted_node.meta[QCOM_QUANT_ATTRS] = prev_node.meta[
                            QCOM_QUANT_ATTRS
                        ]
                    for user in users:
                        user.replace_input_with(prev_node, inserted_node)

            prev_layer = cur_layer
            prev_node = node

    def call(self, graph_module: torch.fx.GraphModule):
        self._insert_fallback_op(graph_module)
        graph_module.recompile()
        return PassResult(graph_module, True)


def split_graph(edge_program: ExportedProgram, num_layers: int, shares: int):
    graph_module = edge_program.graph_module
    # Shard boundaries: the first layer index of each partition.
    shard_layers = list(range(0, num_layers, num_layers // shares))
    return SplitGraph(shard_layers)(graph_module)
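

# ---------------------------------------------------------------------------
# Usage sketch (added for illustration; not part of the original pass).
# It shows how the shard boundaries are derived and how the pass would
# typically be invoked. `edge_program` is assumed to be an already-lowered
# ExportedProgram of a llama model; producing one is outside the scope of
# this file, so that step is only sketched in the comments below.
if __name__ == "__main__":
    num_layers, num_shards = 12, 3
    # Same computation as split_graph: boundaries fall at layers [0, 4, 8],
    # so fallback ops are inserted before layers 4 and 8.
    example_shard_layers = list(range(0, num_layers, num_layers // num_shards))
    print(f"fallback ops inserted before layers {example_shard_layers[1:]}")
    # With a real edge program, the pass would be applied as:
    #     pass_result = split_graph(edge_program, num_layers, num_shards)
    #     sharded_graph_module = pass_result.graph_module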