# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, Mapping

import torch
import torch.nn as nn
from executorch.examples.models.llama.llama_transformer import (
    ModelArgs as LlamaModelArgs,
    Transformer as LlamaTransformer,
)
from executorch.extension.gguf_util.load_gguf import GGUFModelArgs, GGUFWeights


def _create_pt_model(
    gguf_model_args: GGUFModelArgs,
) -> nn.Module:
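    """Instantiate an eval-mode Llama Transformer whose hyperparameters
    mirror the metadata carried in the GGUF header."""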
    llama_model_args = LlamaModelArgs(
        dim=gguf_model_args.embedding_length,
        n_layers=gguf_model_args.block_count,
        n_heads=gguf_model_args.attention.head_count,
        n_kv_heads=gguf_model_args.attention.head_count_kv,
        vocab_size=gguf_model_args.vocab_size,
        norm_eps=gguf_model_args.attention.layer_norm_rms_epsilon,
        hidden_dim=gguf_model_args.feed_forward_length,
        rope_freq_base=gguf_model_args.rope.freq_base,
    )
    pt_model = LlamaTransformer(llama_model_args)
    pt_model.eval()
    return pt_model


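# Substring replacements, applied in order, that translate GGUF tensor names
# into the parameter names used by llama_transformer's Transformer module,
# e.g. "blk.0.attn_q.weight" -> "layers.0.attention.wq.weight".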
_name_replacements = [
    ("blk", "layers"),
    ("token_embd", "tok_embeddings"),
    ("attn_q", "attention.wq"),
    ("attn_k", "attention.wk"),
    ("attn_v", "attention.wv"),
    ("attn_output", "attention.wo"),
    ("attn_norm", "attention_norm"),
    ("output_norm.weight", "norm.weight"),
    ("ffn_down", "feed_forward.w2"),
    ("ffn_gate", "feed_forward.w1"),
    ("ffn_up", "feed_forward.w3"),
]


def _convert_gguf_tensor_name_to_llama_nn(gguf_name: str) -> str:
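    """Map a GGUF tensor name to the corresponding llama_transformer
    parameter name by applying each substring replacement in turn."""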
    result = gguf_name  # str is immutable, so no deep copy is needed
    for gguf_string, replacement in _name_replacements:
        result = result.replace(gguf_string, replacement)
    return result


def _convert_to_state_dict(gguf_weights: GGUFWeights) -> Dict[str, Any]:
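    """Build a PyTorch state dict from the raw GGUF tensors, renaming each
    tensor and reshaping its flat data to the PyTorch dimension order."""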
    state_dict = {}
    for tensor in gguf_weights.tensors:
        gguf_tensor_name = tensor.name
        nn_tensor_name = _convert_gguf_tensor_name_to_llama_nn(gguf_tensor_name)
        # GGUF records dimensions in the reverse order of PyTorch/numpy,
        # so reverse the shape before viewing the flat data.
        reversed_shape = tensor.shape[::-1]
        new_tensor = tensor.data.reshape(reversed_shape)
        state_dict[nn_tensor_name] = torch.from_numpy(new_tensor)

    return state_dict


def _load_weights_into_nn(
    pt_model: nn.Module, gguf_model_args: GGUFModelArgs, gguf_weights: GGUFWeights
):
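    """Convert the GGUF weights to a state dict and load it into the model."""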
    state_dict: Dict[str, Any] = _convert_to_state_dict(gguf_weights)

    # The GGUF file carries no mask tensors, but llama_transformer.py expects
    # a per-layer attention mask, so synthesize one for each layer to keep
    # load_state_dict happy: -inf strictly above the diagonal (and zero
    # elsewhere) masks out future positions, giving causal attention.
    for layer_idx in range(gguf_model_args.block_count):
        mask_name = f"layers.{layer_idx}.attention.mask"
        mask = torch.full(
            (1, 1, pt_model.params.max_seq_len, pt_model.params.max_seq_len),
            float("-inf"),
        )
        mask = torch.triu(mask, diagonal=1)
        state_dict[mask_name] = mask

    pt_model.load_state_dict(state_dict)


def _create_pte_program(pt_model: nn.Module) -> bytes:
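    """Export the PyTorch model to a serialized ExecuTorch program.

    Not yet implemented; currently returns None.
    """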
    # TODO (mnachin): Export
    return


def convert_to_pte(gguf_model_args: GGUFModelArgs, gguf_weights: GGUFWeights) -> bytes:
    """Convert a GGUF model into an ExecuTorch program.

    Args:
        gguf_model_args: The arguments for the GGUF model.
        gguf_weights: The weights of the GGUF model.

    Returns:
        The serialized ExecuTorch program.
    """

    assert (
        gguf_model_args.arch == "llama"
    ), "Only LLaMA models are supported by this converter."

    # Step 1: Create the PyTorch model
    print("Create the PyTorch model")
    pt_model = _create_pt_model(
        gguf_model_args,
    )

    # Step 2: Load the weights into the PyTorch model
    print("Load the weights into the PyTorch model")
    _load_weights_into_nn(pt_model, gguf_model_args, gguf_weights)

    # Step 3: Export to ExecuTorch
    print("Exporting to ExecuTorch")
    pte_program = _create_pte_program(pt_model)
    return pte_program
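

# Example driver (a hypothetical sketch: the loader name and the way the
# .pte bytes are written out are assumptions, not part of this module's API):
#
#     gguf_model_args, gguf_weights = load_gguf_file("model.gguf")
#     pte_program = convert_to_pte(gguf_model_args, gguf_weights)
#     with open("model.pte", "wb") as f:
#         f.write(pte_program)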