xref: /aosp_15_r20/external/pytorch/torch/_export/db/case.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import inspect
3import re
4import string
5from dataclasses import dataclass, field
6from enum import Enum
7from typing import Any, Dict, List, Optional, Set, Tuple
8from types import ModuleType
9
10import torch
11
# Registry of valid example tags, stored as a tree of nested dicts. A tag is
# the dot-joined path from a top-level namespace down the tree, e.g.
# "torch.cond" or "python.control-flow" (see _validate_tag). Every leaf is
# currently an empty dict.
_TAGS: Dict[str, Dict[str, Any]] = {
    namespace: {leaf: {} for leaf in leaves}
    for namespace, leaves in (
        (
            "torch",
            (
                "cond",
                "dynamic-shape",
                "escape-hatch",
                "map",
                "dynamic-value",
                "operator",
                "mutation",
            ),
        ),
        (
            "python",
            (
                "assert",
                "builtin",
                "closure",
                "context-manager",
                "control-flow",
                "data-structure",
                "standard-library",
                "object-model",
            ),
        ),
    )
}
33
34
class SupportLevel(Enum):
    """
    How far export currently handles the feature an example exercises.

    ``SUPPORTED`` marks features export handles today (it is the default
    ``support_level`` of an ``ExportCase``); ``NOT_SUPPORTED_YET`` marks
    features it does not handle yet.
    """

    # Declaration order is also the iteration order of the enum.
    SUPPORTED = 1
    NOT_SUPPORTED_YET = 0
43
44
# Positional example inputs: a tuple of arbitrary values (check_inputs_type
# rejects any args container that is not a tuple).
ArgsType = Tuple[Any, ...]
46
47
def check_inputs_type(args, kwargs):
    """
    Validate the containers holding an example's inputs.

    Checks, in order, that ``args`` is a tuple, that ``kwargs`` is a dict, and
    that every key of ``kwargs`` is a string; the first failing check wins.

    Raises:
        ValueError: describing the first check that failed. For bad keys, the
            type of the first non-string key (in iteration order) is reported.
    """
    if not isinstance(args, tuple):
        raise ValueError(
            f"Expecting args type to be a tuple, got: {type(args)}"
        )
    if not isinstance(kwargs, dict):
        raise ValueError(
            f"Expecting kwargs type to be a dict, got: {type(kwargs)}"
        )

    bad_keys = [key for key in kwargs if not isinstance(key, str)]
    if bad_keys:
        raise ValueError(
            f"Expecting kwargs keys to be a string, got: {type(bad_keys[0])}"
        )
62
def _validate_tag(tag: str):
    """
    Check that a dotted tag names a path in the ``_TAGS`` registry.

    ``tag`` is split on ``"."`` and each part is looked up one level deeper in
    ``_TAGS``. For example, ``"torch.cond"`` is valid while a bare ``"cond"``
    is not. A prefix of a registered path (e.g. ``"torch"``) is also accepted.

    Raises:
        ValueError: if a part contains characters other than lowercase ASCII
            letters and ``-``, or if the dotted path is not registered.
    """
    # Loop-invariant: build the allowed character set once.
    allowed_chars = set(string.ascii_lowercase + "-")
    node = _TAGS
    for part in tag.split("."):
        # Explicit check instead of ``assert`` so validation is not stripped
        # when Python runs with -O.
        if not set(part) <= allowed_chars:
            raise ValueError(f"Tag contains invalid characters: {part}")
        if part not in node:
            raise ValueError(f"Tag {tag} is not found in registered tags.")
        node = node[part]
74
75
@dataclass(frozen=True)
class ExportCase:
    """
    One example program in the export example bank.

    Construction validates the instance: the example inputs must be a tuple
    of positional args plus a dict with string keys, ``extra_args`` (when
    given) must be a tuple, every tag must be registered, and the description
    must be a non-empty string.
    """

    example_args: ArgsType
    # Human-readable explanation of the use case; must be a non-empty str.
    description: str
    model: torch.nn.Module
    name: str
    example_kwargs: Dict[str, Any] = field(default_factory=dict)
    # Alternative positional inputs, used to test graph generalization.
    extra_args: Optional[ArgsType] = None
    # Registered dotted tags for the use case, e.g. "torch.dynamic-shape" or
    # "torch.escape-hatch".
    tags: Set[str] = field(default_factory=set)
    support_level: SupportLevel = SupportLevel.SUPPORTED
    dynamic_shapes: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        check_inputs_type(self.example_args, self.example_kwargs)
        if self.extra_args is not None:
            # Extra inputs are positional only, so validate them with no kwargs.
            check_inputs_type(self.extra_args, {})

        for registered_tag in self.tags:
            _validate_tag(registered_tag)

        has_description = (
            isinstance(self.description, str) and len(self.description) > 0
        )
        if not has_description:
            raise ValueError(f'Invalid description: "{self.description}"')
99
100
# Registered example cases keyed by case name. Only the first case registered
# under a given name is stored here (see register_db_case).
_EXAMPLE_CASES: Dict[str, ExportCase] = {}
# Modules that have already used the export_case decorator.
_MODULES: Set[ModuleType] = set()
# Case names registered more than once, mapped to every case that used the
# name: the originally stored case followed by each later duplicate.
_EXAMPLE_CONFLICT_CASES: Dict[str, List[ExportCase]] = {}
# Cases created by export_rewrite_case, keyed by the parent case's name.
_EXAMPLE_REWRITE_CASES: Dict[str, List[ExportCase]] = {}
105
106
def register_db_case(case: ExportCase) -> None:
    """
    Add a user-provided ``ExportCase`` to the example bank under ``case.name``.

    The first case registered under a name is kept in ``_EXAMPLE_CASES``. A
    later case with the same name does not replace it. Instead, the name is
    recorded in ``_EXAMPLE_CONFLICT_CASES``, mapped to the original case
    followed by every duplicate in registration order.
    """
    case_name = case.name
    if case_name not in _EXAMPLE_CASES:
        _EXAMPLE_CASES[case_name] = case
        return

    # Seed the conflict list with the already-registered case the first time
    # a clash is seen, then record the newcomer.
    _EXAMPLE_CONFLICT_CASES.setdefault(
        case_name, [_EXAMPLE_CASES[case_name]]
    ).append(case)
118
119
def to_snake_case(name):
    """
    Convert a CamelCase identifier to snake_case.

    The first pass inserts ``_`` before each capitalised word that follows
    any character. The second pass splits the remaining lowercase/digit to
    uppercase boundaries. The result is then lowercased.
    """
    words_split = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
    boundaries_split = re.sub("([a-z0-9])([A-Z])", r"\1_\2", words_split)
    return boundaries_split.lower()
123
124
def _make_export_case(m, name, configs):
    """
    Build an ``ExportCase`` for the module instance ``m`` named ``name``.

    ``configs`` supplies the remaining ``ExportCase`` fields and is not
    modified. When it has no ``"description"`` entry, ``m.__doc__`` is used
    instead. ``model`` and ``name`` always come from the arguments and
    override any same-named entries in ``configs``.

    Raises:
        TypeError: if ``m`` is not a ``torch.nn.Module`` instance.
        AssertionError: (via ``assert``) if no description was given and
            ``m`` has no docstring.
    """
    if not isinstance(m, torch.nn.Module):
        raise TypeError("Export case class should be a torch.nn.Module.")

    case_kwargs = {**configs}
    if "description" not in case_kwargs:
        # Fall back to the docstring when no explicit description is given.
        assert (
            m.__doc__ is not None
        ), f"Could not find description or docstring for export case: {m}"
        case_kwargs["description"] = m.__doc__
    case_kwargs["model"] = m
    case_kwargs["name"] = name
    return ExportCase(**case_kwargs)
136
137
def export_case(**kwargs):
    """
    Decorator factory that registers a user-provided model in the example bank.

    ``kwargs`` supply the ``ExportCase`` fields; ``model`` and ``name`` are
    filled in by this decorator. The decorated object becomes an
    ``ExportCase`` named after the last dotted component of its defining
    module. That case is registered with ``register_db_case`` and returned in
    place of the decorated object.

    Raises:
        RuntimeError: if the defining module has already used this decorator.
    """

    def wrapper(m):
        defining_module = inspect.getmodule(m)
        if defining_module in _MODULES:
            raise RuntimeError("export_case should only be used once per example file.")

        assert defining_module is not None
        _MODULES.add(defining_module)
        # e.g. "pkg.examples.foo_bar" -> "foo_bar"
        case_name = defining_module.__name__.rpartition(".")[2]
        new_case = _make_export_case(m, case_name, kwargs)
        register_db_case(new_case)
        return new_case

    return wrapper
157
158
def export_rewrite_case(**kwargs):
    """
    Decorator factory that registers a rewritten variant of an existing case.

    ``kwargs`` must contain ``parent``, the ``ExportCase`` being rewritten.
    Every other entry is forwarded to ``ExportCase``. The new case always
    reuses ``parent.example_args``, overriding any ``example_args`` passed
    here. It is named after the decorated object's ``__name__`` (or, for
    module instances, their class name) converted to snake_case. It is
    appended to ``_EXAMPLE_REWRITE_CASES[parent.name]`` and returned.

    The returned decorator may be applied more than once.

    Raises (from the returned decorator):
        KeyError: if ``parent`` was not supplied.
        AssertionError: (via ``assert``) if ``parent`` is not an ``ExportCase``.
        TypeError: if the decorated object is not a ``torch.nn.Module``.
    """

    def wrapper(m):
        # Work on a copy. ``kwargs`` is shared by every call of this wrapper,
        # so popping "parent" from it would make a second use raise KeyError.
        configs = dict(kwargs)

        parent = configs.pop("parent")
        assert isinstance(parent, ExportCase)

        configs["example_args"] = parent.example_args
        # _make_export_case only accepts nn.Module instances, which have no
        # ``__name__`` attribute, so fall back to the class name.
        raw_name = getattr(m, "__name__", type(m).__name__)
        case = _make_export_case(m, to_snake_case(raw_name), configs)
        # Register only once the case was built, so a failure while building
        # it leaves no empty entry behind.
        _EXAMPLE_REWRITE_CASES.setdefault(parent.name, []).append(case)
        return case

    return wrapper
175