# mypy: allow-untyped-defs
import inspect
import re
import string
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Set, Tuple
from types import ModuleType

import torch

# Registry of valid example tags. A dotted tag such as "torch.cond" is
# resolved by walking this nested mapping one component at a time
# (see `_validate_tag`).
_TAGS: Dict[str, Dict[str, Any]] = {
    "torch": {
        "cond": {},
        "dynamic-shape": {},
        "escape-hatch": {},
        "map": {},
        "dynamic-value": {},
        "operator": {},
        "mutation": {},
    },
    "python": {
        "assert": {},
        "builtin": {},
        "closure": {},
        "context-manager": {},
        "control-flow": {},
        "data-structure": {},
        "standard-library": {},
        "object-model": {},
    },
}


class SupportLevel(Enum):
    """
    Indicates at what stage the feature
    used in the example is handled in export.
    """

    SUPPORTED = 1
    NOT_SUPPORTED_YET = 0


ArgsType = Tuple[Any, ...]


def check_inputs_type(args, kwargs):
    """
    Validate the container types of example inputs.

    Raises:
        ValueError: if ``args`` is not a tuple, ``kwargs`` is not a dict, or
            any key of ``kwargs`` is not a string.
    """
    if not isinstance(args, tuple):
        raise ValueError(
            f"Expecting args type to be a tuple, got: {type(args)}"
        )
    if not isinstance(kwargs, dict):
        raise ValueError(
            f"Expecting kwargs type to be a dict, got: {type(kwargs)}"
        )
    for key in kwargs:
        if not isinstance(key, str):
            raise ValueError(
                f"Expecting kwargs keys to be a string, got: {type(key)}"
            )


def _validate_tag(tag: str):
    """
    Check that a dotted ``tag`` (e.g. ``"torch.cond"``) is registered in ``_TAGS``.

    Each dot-separated component must consist only of lowercase ASCII letters
    and ``-``, and must be a key at the corresponding nesting level of
    ``_TAGS``. A bare top-level category such as ``"torch"`` is accepted.

    Raises:
        AssertionError: if a component contains any other character.
        ValueError: if a component is not found at its nesting level.
    """
    # Loop-invariant: build the allowed character set once.
    allowed_chars = set(string.ascii_lowercase + "-")
    t = _TAGS
    for part in tag.split("."):
        assert set(part) <= allowed_chars, f"Tag contains invalid characters: {part}"
        if part in t:
            t = t[part]
        else:
            raise ValueError(f"Tag {tag} is not found in registered tags.")


@dataclass(frozen=True)
class ExportCase:
    """
    A single example program for export, together with its example inputs
    and metadata. Inputs, tags and description are validated on construction.
    """

    example_args: ArgsType
    description: str  # A description of the use case.
    model: torch.nn.Module
    name: str
    example_kwargs: Dict[str, Any] = field(default_factory=dict)
    extra_args: Optional[ArgsType] = None  # For testing graph generalization.
    # Tags associated with the use case. (e.g. dynamic-shape, escape-hatch)
    tags: Set[str] = field(default_factory=set)
    support_level: SupportLevel = SupportLevel.SUPPORTED
    dynamic_shapes: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        check_inputs_type(self.example_args, self.example_kwargs)
        if self.extra_args is not None:
            check_inputs_type(self.extra_args, {})

        for tag in self.tags:
            _validate_tag(tag)

        if not isinstance(self.description, str) or len(self.description) == 0:
            raise ValueError(f'Invalid description: "{self.description}"')


# Registered cases, keyed by case name.
_EXAMPLE_CASES: Dict[str, ExportCase] = {}
# Modules that have already used `export_case` (one case per module).
_MODULES: Set[ModuleType] = set()
# Every case involved in a name clash, keyed by the clashing name.
_EXAMPLE_CONFLICT_CASES: Dict[str, List[ExportCase]] = {}
# Rewrites of a case, keyed by the parent case's name.
_EXAMPLE_REWRITE_CASES: Dict[str, List[ExportCase]] = {}


def register_db_case(case: ExportCase) -> None:
    """
    Registers a user provided ExportCase into example bank.

    The first case registered under a name is kept in ``_EXAMPLE_CASES``.
    A later case with the same name does not replace it. Instead, the original
    and every later duplicate are recorded in ``_EXAMPLE_CONFLICT_CASES``.
    """
    if case.name in _EXAMPLE_CASES:
        if case.name not in _EXAMPLE_CONFLICT_CASES:
            _EXAMPLE_CONFLICT_CASES[case.name] = [_EXAMPLE_CASES[case.name]]
        _EXAMPLE_CONFLICT_CASES[case.name].append(case)
        return

    _EXAMPLE_CASES[case.name] = case


def to_snake_case(name):
    """Convert a CamelCase identifier to snake_case (``"MyModel"`` -> ``"my_model"``)."""
    name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
    return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()


def _make_export_case(m, name, configs):
    """
    Build an :class:`ExportCase` for model instance ``m`` named ``name``.

    ``configs`` supplies the remaining ``ExportCase`` fields and is not
    mutated. When it has no ``"description"``, the model's docstring is used.

    Raises:
        TypeError: if ``m`` is not a ``torch.nn.Module`` instance.
        AssertionError: if no description is given and ``m`` has no docstring.
    """
    if not isinstance(m, torch.nn.Module):
        raise TypeError("Export case class should be a torch.nn.Module.")

    if "description" not in configs:
        # Fallback to docstring if description is missing.
        assert (
            m.__doc__ is not None
        ), f"Could not find description or docstring for export case: {m}"
        configs = {**configs, "description": m.__doc__}
    return ExportCase(**{**configs, "model": m, "name": name})


def export_case(**kwargs):
    """
    Decorator for registering a user provided case into example bank.

    The case is named after the last dotted component of the module that
    defines the decorated model. Each module may register at most one case.
    ``kwargs`` are forwarded as ``ExportCase`` fields.
    """

    def wrapper(m):
        module = inspect.getmodule(m)
        if module in _MODULES:
            raise RuntimeError("export_case should only be used once per example file.")

        assert module is not None
        module_name = module.__name__.split(".")[-1]
        case = _make_export_case(m, module_name, kwargs)
        # Mark the module only after the case was built successfully, so a
        # failed attempt does not masquerade as a duplicate registration.
        _MODULES.add(module)
        register_db_case(case)
        return case

    return wrapper


def export_rewrite_case(**kwargs):
    """
    Decorator for registering a rewrite of an existing export case.

    ``kwargs`` must contain ``parent``, the :class:`ExportCase` being
    rewritten. The remaining keyword arguments are forwarded as
    ``ExportCase`` fields. The rewrite reuses the parent's ``example_args``,
    is named after the decorated model's class in snake_case, and is
    appended to ``_EXAMPLE_REWRITE_CASES`` under the parent's name.
    """

    def wrapper(m):
        # Work on a copy: `kwargs` is shared by every application of this
        # decorator, and popping "parent" from it in place made a second
        # application fail with KeyError.
        configs = dict(kwargs)

        parent = configs.pop("parent")
        assert isinstance(parent, ExportCase)

        configs["example_args"] = parent.example_args
        # `m` must be an nn.Module *instance* (enforced by _make_export_case),
        # and instances have no `__name__`; fall back to the class name.
        model_name = getattr(m, "__name__", None) or type(m).__name__
        case = _make_export_case(m, to_snake_case(model_name), configs)
        _EXAMPLE_REWRITE_CASES.setdefault(parent.name, []).append(case)
        return case

    return wrapper