xref: /aosp_15_r20/external/executorch/extension/gguf_util/load_gguf.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7from dataclasses import dataclass
8from pathlib import Path
9from typing import Any
10
11import gguf
12from gguf import GGUFValueType, ReaderTensor
13
14
@dataclass
class AttentionArgs:
    """Attention hyperparameters read from GGUF `{arch}.attention.*` metadata keys."""

    # From `{arch}.attention.head_count`.
    head_count: int
    # From `{arch}.attention.head_count_kv` (key/value head count; presumably
    # differs from head_count for grouped-query models — confirm against arch).
    head_count_kv: int
    # From `{arch}.attention.layer_norm_rms_epsilon`.
    layer_norm_rms_epsilon: float
20
21
@dataclass
class RopeArgs:
    """Rotary position embedding parameters read from GGUF metadata."""

    # From `{arch}.rope.freq_base`; defaults to 1e4 when the key is absent.
    freq_base: float
26
@dataclass
class GGUFModelArgs:
    """Structured model hyperparameters extracted from a GGUF file's metadata."""

    # Model architecture name, from `general.architecture` (e.g. "llama").
    arch: str
    # From `{arch}.embedding_length`.
    embedding_length: int
    # Number of transformer blocks, from `{arch}.block_count`.
    block_count: int
    # From `{arch}.feed_forward_length`.
    feed_forward_length: int
    # Derived as len(metadata["tokenizer.ggml.tokens"]).
    vocab_size: int
    # Attention-related sub-parameters.
    attention: AttentionArgs
    # RoPE-related sub-parameters.
    rope: RopeArgs
36
37
@dataclass
class GGUFWeights:
    """Raw tensor list taken directly from `gguf.GGUFReader.tensors`."""

    # Each entry is a gguf ReaderTensor (name, quantization type, data, ...).
    tensors: list[ReaderTensor]
41
42
def _get_metadata(reader: gguf.GGUFReader) -> dict[str, Any]:
    """Decode the key/value metadata section of a GGUF file into a plain dict.

    Args:
        reader: An open ``gguf.GGUFReader``.

    Returns:
        Mapping from metadata field name to its decoded Python value:
        arrays become lists, strings are UTF-8 decoded, and scalar fields
        become single Python numbers.
    """
    metadata: dict[str, Any] = {}

    # NOTE: the original iterated with enumerate() but never used the index
    # (it was also shadowed by the comprehension variable below) — dropped.
    for field in reader.fields.values():
        if field.types[:1] == [GGUFValueType.ARRAY]:
            # For arrays, the last entry of `types` is the element type.
            itype = field.types[-1]
            if itype == GGUFValueType.STRING:
                val = [
                    str(bytes(field.parts[part_idx]), encoding="utf-8")
                    for part_idx in field.data
                ]
            else:
                # Numeric array: flatten each referenced part into one list.
                val = [
                    pv
                    for part_idx in field.data
                    for pv in field.parts[part_idx].tolist()
                ]
        elif field.types[0] == GGUFValueType.STRING:
            val = str(bytes(field.parts[-1]), encoding="utf-8")
        else:
            # Scalar field: a single value in the last part.
            val = field.parts[-1].tolist()[0]

        metadata[field.name] = val

    return metadata
64
65
def _build_model_args(metadata: dict[str, Any]) -> GGUFModelArgs:
    """Assemble a GGUFModelArgs from the decoded GGUF metadata dict.

    Args:
        metadata: Field-name -> value mapping as produced by ``_get_metadata``.

    Returns:
        A populated ``GGUFModelArgs`` for the architecture named by
        ``general.architecture``.
    """
    arch = metadata["general.architecture"]

    attention = AttentionArgs(
        head_count=metadata[f"{arch}.attention.head_count"],
        head_count_kv=metadata[f"{arch}.attention.head_count_kv"],
        layer_norm_rms_epsilon=metadata[f"{arch}.attention.layer_norm_rms_epsilon"],
    )
    # 1e4 is the default value from the llama2 model definition.
    rope = RopeArgs(freq_base=metadata.get(f"{arch}.rope.freq_base", 1e4))

    return GGUFModelArgs(
        arch=arch,
        embedding_length=metadata[f"{arch}.embedding_length"],
        block_count=metadata[f"{arch}.block_count"],
        feed_forward_length=metadata[f"{arch}.feed_forward_length"],
        vocab_size=len(metadata["tokenizer.ggml.tokens"]),
        attention=attention,
        rope=rope,
    )
85
86
def load_file(gguf_file: str) -> tuple[GGUFModelArgs, GGUFWeights]:
    """Load a GGUF file and return the model arguments and weights.

    Args:
        gguf_file: Path to an existing GGUF file on disk.

    Returns:
        A ``(model_args, weights)`` tuple: the parsed hyperparameters and the
        raw reader tensors.

    Raises:
        ValueError: If ``gguf_file`` does not name an existing regular file.
    """
    # Fix: the original annotated the return as `(GGUFModelArgs, GGUFWeights)`,
    # which is a runtime tuple of classes, not a valid type annotation; use the
    # PEP 585 `tuple[...]` form already used elsewhere in this file.
    if not Path(gguf_file).is_file():
        raise ValueError(f"Could not find file {gguf_file}")

    reader = gguf.GGUFReader(gguf_file, "r")

    # Step 1: Build GGUFModelArgs
    metadata = _get_metadata(reader)
    model_args = _build_model_args(metadata)

    # Step 2: Build GGUFWeights
    gguf_weights = GGUFWeights(tensors=reader.tensors)

    return (model_args, gguf_weights)
104