xref: /aosp_15_r20/external/executorch/examples/models/checkpoint.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8
9from pathlib import Path
10from typing import Any, Dict, Optional
11
12
def get_default_model_resource_dir(model_file_path: str) -> Path:
    """
    Get the default path to resource files (such as the checkpoint and
    param files), using one of two strategies:
    1. Resolve the `params` resource of the
       `executorch.examples.models.<model_name>` package via pkg_resources
       (only works with buck2).
    2. Otherwise, fall back to the `params` directory located next to
       `model_file_path`.

    Expected to be called from within a `model.py` file located in a
    `executorch/examples/models/<model_name>` directory.

    Args:
        model_file_path: The file path to the eager model definition.
            For example, `executorch/examples/models/llama/model.py`,
            where `executorch/examples/models/llama` contains all
            the llama2-related files.

    Returns:
        The path to the resource directory containing checkpoint, params, etc.
        The directory is not checked for existence.
    """

    try:
        import pkg_resources

        # 1st way: If we can import this path, we are running with buck2 and
        # all resources can be accessed with pkg_resources.
        # pyre-ignore
        from executorch.examples.models.llama import params  # noqa

        # Derive the model name from the parent directory of the given file,
        # assuming a path such as examples/models/<model_name>/model.py.
        model_name = Path(model_file_path).parent.name
        resource_dir = Path(
            pkg_resources.resource_filename(
                f"executorch.examples.models.{model_name}", "params"
            )
        )
    except Exception:
        # 2nd way: any import or resource lookup failure falls back to the
        # directory next to the model file. `Exception` (rather than a bare
        # `except`) lets KeyboardInterrupt and SystemExit propagate.
        resource_dir = Path(model_file_path).absolute().parent / "params"

    return resource_dir
53
54
def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[str]:
    """
    Report the dtype of a checkpoint, taken from its first entry.

    The `dtype` attribute of the first value in `checkpoint` is treated as
    the checkpoint's dtype. Every entry is then compared against it; if any
    differ, a "Mixed dtype model" message listing the differing
    (key, dtype) pairs is printed. Mismatches are only reported, never raised.

    Args:
        checkpoint: Mapping from names to values exposing a `dtype` attribute.

    Returns:
        The `dtype` attribute of the first entry, or None if the checkpoint
        is empty.
    """
    # Nothing to inspect in an empty checkpoint.
    if not checkpoint:
        return None

    first_key = next(iter(checkpoint))
    first = checkpoint[first_key]
    dtype = first.dtype

    # Collect every entry whose dtype disagrees with the first entry's dtype.
    mismatched_dtypes = [
        (name, entry.dtype)
        for name, entry in checkpoint.items()
        if entry.dtype != dtype
    ]
    if mismatched_dtypes:
        print(
            f"Mixed dtype model. Dtype of {first_key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
        )

    return dtype
74