#!/usr/bin/env python3 import argparse import os import sys from typing import Set # Note - hf and timm have their own version of this, torchbench does not # TOOD(voz): Someday, consolidate all the files into one runner instead of a shim like this... def model_names(filename: str) -> Set[str]: names = set() with open(filename) as fh: lines = fh.readlines() lines = [line.rstrip() for line in lines] for line in lines: line_parts = line.split(" ") if len(line_parts) == 1: line_parts = line.split(",") model_name = line_parts[0] names.add(model_name) return names TIMM_MODEL_NAMES = model_names( os.path.join(os.path.dirname(__file__), "timm_models_list.txt") ) HF_MODELS_FILE_NAME = model_names( os.path.join(os.path.dirname(__file__), "huggingface_models_list.txt") ) TORCHBENCH_MODELS_FILE_NAME = model_names( os.path.join(os.path.dirname(__file__), "all_torchbench_models_list.txt") ) # timm <> HF disjoint assert TIMM_MODEL_NAMES.isdisjoint(HF_MODELS_FILE_NAME) # timm <> torch disjoint assert TIMM_MODEL_NAMES.isdisjoint(TORCHBENCH_MODELS_FILE_NAME) # torch <> hf disjoint assert TORCHBENCH_MODELS_FILE_NAME.isdisjoint(HF_MODELS_FILE_NAME) def parse_args(args=None): parser = argparse.ArgumentParser() parser.add_argument( "--only", help="""Run just one model from whichever model suite it belongs to. Or specify the path and class name of the model in format like: --only=path:,class: Due to the fact that dynamo changes current working directory, the path should be an absolute path. The class should have a method get_example_inputs to return the inputs for the model. An example looks like ``` class LinearModel(nn.Module): def __init__(self): super().__init__() self.linear = nn.Linear(10, 10) def forward(self, x): return self.linear(x) def get_example_inputs(self): return (torch.randn(2, 10),) ``` """, ) return parser.parse_known_args(args) if __name__ == "__main__": args, unknown = parse_args() if args.only: name = args.only if name in TIMM_MODEL_NAMES: import timm_models timm_models.timm_main() elif name in HF_MODELS_FILE_NAME: import huggingface huggingface.huggingface_main() elif name in TORCHBENCH_MODELS_FILE_NAME: import torchbench torchbench.torchbench_main() else: print(f"Illegal model name? {name}") sys.exit(-1) else: import torchbench torchbench.torchbench_main() import huggingface huggingface.huggingface_main() import timm_models timm_models.timm_main()