xref: /aosp_15_r20/external/pytorch/torch/_export/db/examples/__init__.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import dataclasses
3import glob
4import inspect
5from os.path import basename, dirname, isfile, join
6
7import torch
8from torch._export.db.case import (
9    _EXAMPLE_CASES,
10    _EXAMPLE_CONFLICT_CASES,
11    _EXAMPLE_REWRITE_CASES,
12    SupportLevel,
13    export_case,
14    ExportCase,
15)
16
17
def _collect_examples():
    """Import every example module in this package and hand its config to ``export_case``.

    Every ``*.py`` file next to this ``__init__.py`` (excluding ``__init__.py``
    itself) is imported as a submodule of this package. Module-level
    attributes whose names match a field of ``ExportCase`` are forwarded as
    keyword arguments to ``export_case``. The resulting decorator is then
    applied to the module's ``model`` attribute, which every example module
    must define.

    Files are processed in sorted order. ``glob.glob`` returns paths in an
    arbitrary, filesystem-dependent order, and sorting keeps the order in
    which cases reach ``export_case`` reproducible across machines.
    """
    import importlib

    example_dir = dirname(__file__)
    case_names = [
        basename(path)[:-3]  # strip the ".py" suffix
        for path in sorted(glob.glob(join(example_dir, "*.py")))
        if isfile(path) and not path.endswith("__init__.py")
    ]

    case_fields = {field.name for field in dataclasses.fields(ExportCase)}
    for case_name in case_names:
        # Equivalent to ``from . import <case_name>`` relative to this package.
        case = importlib.import_module(f".{case_name}", __package__)
        variables = [name for name in dir(case) if name in case_fields]
        export_case(**{v: getattr(case, v) for v in variables})(case.model)


# Runs once at import time, so every example module is processed before the
# accessors below are used.
_collect_examples()
31
def all_examples():
    """Return the collection of all export example cases.

    The object returned is ``_EXAMPLE_CASES`` itself, not a copy. Any caller
    that mutates it affects every other caller.
    """
    registry = _EXAMPLE_CASES
    return registry
34
35
if _EXAMPLE_CONFLICT_CASES:
    # _EXAMPLE_CONFLICT_CASES maps a case name to the cases that share it.
    # Any entry means the example database is inconsistent, so importing this
    # module fails with a message listing every clash.

    def get_name(case):
        """Return a readable name for ``case.model``.

        For an ``nn.Module`` instance this is its class name; otherwise it is
        the model object's own ``__name__``.
        """
        model = case.model
        if isinstance(model, torch.nn.Module):
            model = type(model)
        return model.__name__

    # Collect the pieces and join them once instead of repeated string +=.
    parts = ["Error on conflict export case name.\n"]
    for case_name, cases in _EXAMPLE_CONFLICT_CASES.items():
        parts.append(f"Case name {case_name} is associated with multiple cases:\n  ")
        parts.append(f"[{','.join(map(get_name, cases))}]\n")

    raise RuntimeError("".join(parts))
50
51
def filter_examples_by_support_level(support_level: SupportLevel):
    """Select the example cases whose ``support_level`` equals *support_level*.

    Args:
        support_level: Value compared (with ``==``) against each case's
            ``support_level`` attribute.

    Returns:
        A new dict holding only the matching key/case pairs from
        ``all_examples()``, in their original iteration order.
    """
    matching = {}
    for name, example in all_examples().items():
        if example.support_level == support_level:
            matching[name] = example
    return matching
58
59
def get_rewrite_cases(case):
    """Return the entries stored in ``_EXAMPLE_REWRITE_CASES`` under ``case.name``.

    A fresh empty list is returned when no entry exists for that name.
    """
    lookup_key = case.name
    return _EXAMPLE_REWRITE_CASES.get(lookup_key, [])
62