xref: /aosp_15_r20/external/executorch/extension/gguf_util/converter.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1*523fa7a6SAndroid Build Coastguard Worker# Copyright (c) Meta Platforms, Inc. and affiliates.
2*523fa7a6SAndroid Build Coastguard Worker# All rights reserved.
3*523fa7a6SAndroid Build Coastguard Worker#
4*523fa7a6SAndroid Build Coastguard Worker# This source code is licensed under the BSD-style license found in the
5*523fa7a6SAndroid Build Coastguard Worker# LICENSE file in the root directory of this source tree.
6*523fa7a6SAndroid Build Coastguard Worker
7*523fa7a6SAndroid Build Coastguard Workerfrom executorch.extension.gguf_util.load_gguf import GGUFModelArgs, GGUFWeights
8*523fa7a6SAndroid Build Coastguard Worker
9*523fa7a6SAndroid Build Coastguard Worker
def convert_to_pte(model_args: GGUFModelArgs, weights: GGUFWeights) -> None:
    """Convert a GGUF model into a PTE file, an ExecuTorch program.

    Dispatches on ``model_args.arch`` to an architecture-specific converter.
    Only the ``"llama"`` architecture is currently handled here.

    Args:
        model_args: The arguments for the GGUF model. Its ``arch`` attribute
            (compared against architecture-name strings) selects the converter.
        weights: The weights of the GGUF model.

    Returns:
        The llama converter's return value, passed through unchanged.
        NOTE(review): the annotation says ``None``; confirm the llama
        converter indeed returns nothing.

    Raises:
        NotImplementedError: If ``model_args.arch`` is not a supported
            architecture name.
    """
    arch = model_args.arch

    # Dispatch on the architecture name (a string, not an enum). Each
    # supported architecture has its own converter module, imported inside its
    # branch so only the converter actually needed is loaded.
    if arch == "llama":
        from executorch.extension.gguf_util.converters.llama_converter import (
            convert_to_pte as llama_convert_to_pte,
        )

        return llama_convert_to_pte(model_args, weights)

    # Name the offending architecture so the failure is actionable.
    raise NotImplementedError(f"Unsupported architecture: {arch!r}.")
28