# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict

import torch
import torch.nn as nn
from executorch.examples.models.llama.llama_transformer import (
    ModelArgs as LlamaModelArgs,
    Transformer as LlamaTransformer,
)
from executorch.extension.gguf_util.load_gguf import GGUFModelArgs, GGUFWeights


def _create_pt_model(
    gguf_model_args: GGUFModelArgs,
) -> nn.Module:
    # Map the GGUF hyperparameters onto the llama_transformer ModelArgs and
    # build an eval-mode Transformer.
    llama_model_args = LlamaModelArgs(
        dim=gguf_model_args.embedding_length,
        n_layers=gguf_model_args.block_count,
        n_heads=gguf_model_args.attention.head_count,
        n_kv_heads=gguf_model_args.attention.head_count_kv,
        vocab_size=gguf_model_args.vocab_size,
        norm_eps=gguf_model_args.attention.layer_norm_rms_epsilon,
        hidden_dim=gguf_model_args.feed_forward_length,
        rope_freq_base=gguf_model_args.rope.freq_base,
    )
    pt_model = LlamaTransformer(llama_model_args)
    pt_model.eval()
    return pt_model


# Substring replacements that map GGUF tensor names onto the module paths
# used by llama_transformer.py.
_name_replacements = [
    ("blk", "layers"),
    ("token_embd", "tok_embeddings"),
    ("attn_q", "attention.wq"),
    ("attn_k", "attention.wk"),
    ("attn_v", "attention.wv"),
    ("attn_output", "attention.wo"),
    ("attn_norm", "attention_norm"),
    ("output_norm.weight", "norm.weight"),
    ("ffn_down", "feed_forward.w2"),
    ("ffn_gate", "feed_forward.w1"),
    ("ffn_up", "feed_forward.w3"),
]


def _convert_gguf_tensor_name_to_llama_nn(gguf_name: str) -> str:
    # Strings are immutable and str.replace returns a new string, so no
    # copy of the input is needed.
    result = gguf_name
    for gguf_string, replacement in _name_replacements:
        result = result.replace(gguf_string, replacement)
    return result
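
# A few representative mappings produced by the replacement table above
# (input names follow GGUF's llama tensor-naming convention):
#   "token_embd.weight"     -> "tok_embeddings.weight"
#   "blk.0.attn_q.weight"   -> "layers.0.attention.wq.weight"
#   "blk.0.ffn_gate.weight" -> "layers.0.feed_forward.w1.weight"
#   "output_norm.weight"    -> "norm.weight"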

def _convert_to_state_dict(gguf_weights: GGUFWeights) -> Dict[str, Any]:
    state_dict: Dict[str, Any] = {}
    for tensor in gguf_weights.tensors:
        gguf_tensor_name = tensor.name
        nn_tensor_name = _convert_gguf_tensor_name_to_llama_nn(gguf_tensor_name)
        # GGUF lists tensor dimensions in the reverse of PyTorch's order, so
        # reverse the shape before reshaping the flat data buffer.
        reversed_shape = tensor.shape[::-1]
        new_tensor = tensor.data.reshape(reversed_shape)
        state_dict[nn_tensor_name] = torch.from_numpy(new_tensor)

    return state_dict


def _load_weights_into_nn(
    pt_model: nn.Module, gguf_model_args: GGUFModelArgs, gguf_weights: GGUFWeights
) -> None:
    state_dict: Dict[str, Any] = _convert_to_state_dict(gguf_weights)

    # Fake-initialize the causal attention masks so that the state dict
    # matches the mask buffers registered by llama_transformer.py.
    for layer_id in range(gguf_model_args.block_count):
        mask_name = f"layers.{layer_id}.attention.mask"
        mask = torch.full(
            (1, 1, pt_model.params.max_seq_len, pt_model.params.max_seq_len),
            float("-inf"),
        )
        mask = torch.triu(mask, diagonal=1)
        state_dict[mask_name] = mask

    pt_model.load_state_dict(state_dict)


def _create_pte_program(pt_model: nn.Module) -> bytes:
    # TODO (mnachin): Export
    return


def convert_to_pte(gguf_model_args: GGUFModelArgs, gguf_weights: GGUFWeights) -> bytes:
    """Convert a GGUF model into an ExecuTorch program.

    Args:
        gguf_model_args: The arguments for the GGUF model.
        gguf_weights: The weights of the GGUF model.
    """
    assert (
        gguf_model_args.arch == "llama"
    ), "Only LLaMa models are supported by this converter."

    # Step 1: Create the PyTorch model
    print("Create the PyTorch model")
    pt_model = _create_pt_model(
        gguf_model_args,
    )

    # Step 2: Load the weights into the PyTorch model
    print("Load the weights into the PyTorch model")
    _load_weights_into_nn(pt_model, gguf_model_args, gguf_weights)

    # Step 3: Export to ExecuTorch
    print("Export to ExecuTorch")
    pte_program = _create_pte_program(pt_model)
    return pte_program
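

# Minimal usage sketch, kept out of the converter proper. It assumes the
# sibling loader `load_file` in load_gguf returns a (GGUFModelArgs,
# GGUFWeights) pair for a .gguf path; adjust if the loader's API differs.
# Since _create_pte_program is still a TODO, the resulting program is
# currently empty.
if __name__ == "__main__":
    import sys

    from executorch.extension.gguf_util.load_gguf import load_file

    gguf_model_args, gguf_weights = load_file(sys.argv[1])  # path to a .gguf file
    pte_program = convert_to_pte(gguf_model_args, gguf_weights)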