# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from pathlib import Path
from typing import Any

import gguf
from gguf import GGUFValueType, ReaderTensor


@dataclass
class AttentionArgs:
    """Attention settings read from the ``{arch}.attention.*`` metadata keys."""

    # From "{arch}.attention.head_count".
    head_count: int
    # From "{arch}.attention.head_count_kv".
    head_count_kv: int
    # From "{arch}.attention.layer_norm_rms_epsilon".
    layer_norm_rms_epsilon: float


@dataclass
class RopeArgs:
    """RoPE settings read from the ``{arch}.rope.*`` metadata keys."""

    # From "{arch}.rope.freq_base"; falls back to 1e4 when the key is absent.
    freq_base: float


@dataclass
class GGUFModelArgs:
    """Model hyperparameters assembled from the metadata of a GGUF file."""

    # Value of the "general.architecture" key; also the prefix of the
    # architecture-specific keys below.
    arch: str
    embedding_length: int
    block_count: int
    feed_forward_length: int
    # Number of entries in "tokenizer.ggml.tokens".
    vocab_size: int
    attention: AttentionArgs
    rope: RopeArgs


@dataclass
class GGUFWeights:
    """Container for the tensors exposed by a ``gguf.GGUFReader``."""

    tensors: list[ReaderTensor]


def _get_metadata(reader: gguf.GGUFReader) -> dict[str, Any]:
    """Convert the fields of a GGUF reader into a plain dictionary.

    Args:
        reader: An open GGUF reader whose ``fields`` are to be decoded.

    Returns:
        A mapping from each field's name to its decoded value:

        * array of strings -> list of UTF-8 decoded ``str``
        * any other array  -> flat list built from each part's ``tolist()``
        * scalar string    -> UTF-8 decoded ``str``
        * anything else    -> first element of the last part's ``tolist()``
    """
    metadata: dict[str, Any] = {}

    for field in reader.fields.values():
        if field.types[:1] == [GGUFValueType.ARRAY]:
            # The last entry of ``types`` is the element type of the array;
            # ``field.data`` holds the indices of the parts with the elements.
            item_type = field.types[-1]
            if item_type == GGUFValueType.STRING:
                val = [
                    str(bytes(field.parts[part_idx]), encoding="utf-8")
                    for part_idx in field.data
                ]
            else:
                val = [
                    item
                    for part_idx in field.data
                    for item in field.parts[part_idx].tolist()
                ]
        elif field.types[0] == GGUFValueType.STRING:
            val = str(bytes(field.parts[-1]), encoding="utf-8")
        else:
            val = field.parts[-1].tolist()[0]

        metadata[field.name] = val

    return metadata


def _build_model_args(metadata: dict[str, Any]) -> GGUFModelArgs:
    """Build ``GGUFModelArgs`` from decoded GGUF metadata.

    Args:
        metadata: Mapping of metadata key to decoded value. It must contain
            "general.architecture", "tokenizer.ggml.tokens" and the
            architecture-prefixed keys read below.

    Returns:
        The populated model arguments.

    Raises:
        KeyError: If a required metadata key is missing.
    """
    arch = metadata["general.architecture"]

    return GGUFModelArgs(
        arch=arch,
        embedding_length=metadata[f"{arch}.embedding_length"],
        block_count=metadata[f"{arch}.block_count"],
        feed_forward_length=metadata[f"{arch}.feed_forward_length"],
        vocab_size=len(metadata["tokenizer.ggml.tokens"]),
        attention=AttentionArgs(
            head_count=metadata[f"{arch}.attention.head_count"],
            head_count_kv=metadata[f"{arch}.attention.head_count_kv"],
            layer_norm_rms_epsilon=metadata[
                f"{arch}.attention.layer_norm_rms_epsilon"
            ],
        ),
        rope=RopeArgs(
            # default value from llama2 model definition
            freq_base=metadata.get(f"{arch}.rope.freq_base", 1e4),
        ),
    )


def load_file(gguf_file: str) -> tuple[GGUFModelArgs, GGUFWeights]:
    """Load a GGUF file and return the model arguments and weights.

    Args:
        gguf_file: Path to the GGUF file on disk.

    Returns:
        A ``(model_args, weights)`` pair, where ``weights`` wraps the
        reader's tensors.

    Raises:
        ValueError: If ``gguf_file`` does not point to an existing file.
    """
    if not Path(gguf_file).is_file():
        raise ValueError(f"Could not find file {gguf_file}")

    reader = gguf.GGUFReader(gguf_file, "r")

    # Step 1: Build GGUFModelArgs
    metadata = _get_metadata(reader)
    model_args = _build_model_args(metadata)

    # Step 2: Build GGUFWeights
    gguf_weights = GGUFWeights(tensors=reader.tensors)

    return model_args, gguf_weights