# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

from pathlib import Path
from typing import Any, Dict, Optional


def get_default_model_resource_dir(model_file_path: str) -> Path:
    """
    Get the default path to resource files (which contain files such as the
    checkpoint and param files), either:
    1. Uses the path from pkg_resources, which only works with buck2.
    2. Falls back to the `params` directory next to `model_file_path`
       (e.g. `examples/models/llama/params` for the llama model).

    Expected to be called from within a `model.py` file located in a
    `executorch/examples/models/<model_name>` directory.

    Args:
        model_file_path: The file path to the eager model definition.
            For example, `executorch/examples/models/llama/model.py`,
            where `executorch/examples/models/llama` contains all
            the llama2-related files.

    Returns:
        The path to the resource directory containing checkpoint, params, etc.
    """
    model_file = Path(model_file_path)

    try:
        import pkg_resources

        # 1st way: If we can import this path, we are running with buck2 and
        # all resources can be accessed with pkg_resources.
        # pyre-ignore
        from executorch.examples.models.llama import params  # noqa

        # Derive the model name from the directory containing the model file,
        # assuming a layout such as examples/models/<model_name>/model.py.
        model_name = model_file.parent.name
        resource_dir = Path(
            pkg_resources.resource_filename(
                f"executorch.examples.models.{model_name}", "params"
            )
        )
    except Exception:
        # 2nd way: any failure of the buck2/pkg_resources probe (missing
        # package, missing resource, ...) falls back to the local `params`
        # directory. `Exception` rather than a bare `except:` so that
        # KeyboardInterrupt / SystemExit are not swallowed.
        resource_dir = model_file.absolute().parent / "params"

    return resource_dir


def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[Any]:
    """
    Get the dtype of the checkpoint, returning None if the checkpoint is empty.

    The dtype of the first entry (in dict iteration order) is taken as the
    checkpoint's dtype. If any other entry has a different dtype, a
    diagnostic listing the mismatches is printed, and the first entry's
    dtype is still returned.

    Args:
        checkpoint: Mapping from parameter name to a value exposing a
            `.dtype` attribute (presumably tensors, e.g. a loaded state
            dict — confirm against callers).

    Returns:
        The `.dtype` of the first entry, or None for an empty checkpoint.
        Note: this is the dtype object itself, not a string.
    """
    if not checkpoint:
        return None

    first_key = next(iter(checkpoint))
    first = checkpoint[first_key]
    dtype = first.dtype

    mismatched_dtypes = [
        (key, value.dtype)
        for key, value in checkpoint.items()
        if value.dtype != dtype
    ]
    if mismatched_dtypes:
        print(
            f"Mixed dtype model. Dtype of {first_key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
        )
    return dtype