# mypy: allow-untyped-defs
import dataclasses
import glob
import inspect
from os.path import basename, dirname, isfile, join

import torch
from torch._export.db.case import (
    _EXAMPLE_CASES,
    _EXAMPLE_CONFLICT_CASES,
    _EXAMPLE_REWRITE_CASES,
    SupportLevel,
    export_case,
    ExportCase,
)


def _collect_examples():
    """Import every sibling example module and register it via ``export_case``.

    Each ``*.py`` file in this package's directory, other than
    ``__init__.py``, is imported relative to this package. Every attribute of
    the imported module whose name matches a field of ``ExportCase`` is passed
    as a keyword argument to ``export_case``, and the resulting decorator is
    applied to the module's ``model`` attribute.
    """
    # glob() yields paths in arbitrary filesystem order; sort them so the
    # registration order (and hence the iteration order of all_examples()
    # and of any conflict report) is reproducible across machines.
    paths = sorted(glob.glob(join(dirname(__file__), "*.py")))
    case_names = [
        # The "*.py" pattern guarantees the suffix, so drop its 3 characters.
        basename(path)[:-3]
        for path in paths
        # Exact basename match: only skip the package's own __init__.py, not
        # any example whose name merely ends in "__init__".
        if isfile(path) and basename(path) != "__init__.py"
    ]

    case_fields = {field.name for field in dataclasses.fields(ExportCase)}
    for case_name in case_names:
        # level=1 makes this a relative import of the sibling module.
        case = __import__(case_name, globals(), locals(), [], 1)
        kwargs = {name: getattr(case, name) for name in dir(case) if name in case_fields}
        export_case(**kwargs)(case.model)


_collect_examples()


def all_examples():
    """Return the registry of all collected export cases (``_EXAMPLE_CASES``)."""
    return _EXAMPLE_CASES


# Fail loudly at import time if any case name was registered more than once.
if _EXAMPLE_CONFLICT_CASES:

    def get_name(case):
        """Return a readable name for the model attached to ``case``."""
        model = case.model
        # nn.Module instances have no __name__ of their own; report the class.
        if isinstance(model, torch.nn.Module):
            model = type(model)
        return model.__name__

    msg = "Error on conflict export case name.\n"
    for case_name, cases in _EXAMPLE_CONFLICT_CASES.items():
        msg += f"Case name {case_name} is associated with multiple cases:\n "
        msg += f"[{','.join(map(get_name, cases))}]\n"

    raise RuntimeError(msg)


def filter_examples_by_support_level(support_level: SupportLevel):
    """Return the entries of ``all_examples()`` whose ``support_level`` equals
    ``support_level``, as a new dict with the same keys."""
    return {
        key: val
        for key, val in all_examples().items()
        if val.support_level == support_level
    }


def get_rewrite_cases(case):
    """Return the rewrite cases registered under ``case.name``.

    Returns an empty list when none are registered.
    """
    return _EXAMPLE_REWRITE_CASES.get(case.name, [])