# mypy: allow-untyped-defs
"""Module for handling symbolic function registration."""

import warnings
from typing import (
    Callable,
    Collection,
    Dict,
    Generic,
    Optional,
    Sequence,
    Set,
    TypeVar,
    Union,
)

from torch.onnx import _constants, errors


OpsetVersion = int


def _dispatch_opset_version(
    target: OpsetVersion, registered_opsets: Collection[OpsetVersion]
) -> Optional[OpsetVersion]:
    """Finds the registered opset given a target opset version and the available opsets.

    For ``target >= _constants.ONNX_BASE_OPSET`` the highest registered version that is
    ``<= target`` is chosen. For older (legacy) targets the lowest registered version in
    ``[target, _constants.ONNX_BASE_OPSET]`` is chosen instead.

    Args:
        target: The target opset version.
        registered_opsets: The available opsets.

    Returns:
        The registered opset version, or ``None`` if no registered version matches.
    """
    if not registered_opsets:
        return None

    descending_registered_versions = sorted(registered_opsets, reverse=True)
    # Linear search for the opset version, which is fine since the number of opset
    # versions is small.

    if target >= _constants.ONNX_BASE_OPSET:
        # Always look down toward opset 1 when the target is >= ONNX_BASE_OPSET (opset 9).
        # When a custom op is registered at opset 1, we want to be able to discover it as
        # a fallback for all opsets >= ONNX_BASE_OPSET.
        for version in descending_registered_versions:
            if version <= target:
                return version
        return None

    # target < opset 9. This is the legacy behavior to support opset 7 and opset 8
    # for caffe2 support. We search up toward opset 9.
    for version in reversed(descending_registered_versions):
        # Count back up until _constants.ONNX_BASE_OPSET
        if target <= version <= _constants.ONNX_BASE_OPSET:
            return version

    return None


_K = TypeVar("_K")
_V = TypeVar("_V")


class OverrideDict(Collection[_K], Generic[_K, _V]):
    """A dictionary that merges built-in and custom symbolic functions.

    It supports overriding and un-overriding built-in symbolic functions with custom
    ones. Item access, ``get``, membership, iteration, ``len`` and truthiness all use
    the merged view, in which an override takes precedence over the base value for
    the same key.
    """

    def __init__(self) -> None:
        self._base: Dict[_K, _V] = {}
        self._overrides: Dict[_K, _V] = {}
        # Merged view kept in sync by every mutator: base values with overrides on top.
        self._merged: Dict[_K, _V] = {}

    def set_base(self, key: _K, value: _V) -> None:
        """Sets a base value; it becomes visible only if ``key`` is not overridden."""
        self._base[key] = value
        if key not in self._overrides:
            self._merged[key] = value

    def in_base(self, key: _K) -> bool:
        """Checks if a key is in the base dictionary."""
        return key in self._base

    def override(self, key: _K, value: _V) -> None:
        """Overrides a base key-value with a new pair."""
        self._overrides[key] = value
        self._merged[key] = value

    def remove_override(self, key: _K) -> None:
        """Un-overrides a key-value pair, restoring the base value if there is one."""
        self._overrides.pop(key, None)  # type: ignore[arg-type]
        self._merged.pop(key, None)  # type: ignore[arg-type]
        if key in self._base:
            self._merged[key] = self._base[key]

    def overridden(self, key: _K) -> bool:
        """Checks if a key-value pair is overridden."""
        return key in self._overrides

    def __getitem__(self, key: _K) -> _V:
        return self._merged[key]

    def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
        return self._merged.get(key, default)

    def __contains__(self, key: object) -> bool:
        return key in self._merged

    def __iter__(self):
        return iter(self._merged)

    def __len__(self) -> int:
        return len(self._merged)

    def __repr__(self) -> str:
        return f"OverrideDict(base={self._base}, overrides={self._overrides})"

    def __bool__(self) -> bool:
        return bool(self._merged)


class _SymbolicFunctionGroup:
    """Different versions of symbolic functions registered to the same name.

    Every lookup sorts the registered opset versions and scans them linearly, which is
    cheap because an op only has a handful of registered versions.

    Function overloads with different arguments are not allowed: there is exactly one
    function per opset version. Custom op overrides are supported.
    """

    def __init__(self, name: str) -> None:
        self._name = name
        # A dictionary of functions, keyed by the opset version.
        self._functions: OverrideDict[OpsetVersion, Callable] = OverrideDict()

    def __repr__(self) -> str:
        return f"_SymbolicFunctionGroup({self._name}, registered={self._functions})"

    def __getitem__(self, key: OpsetVersion) -> Callable:
        result = self.get(key)
        if result is None:
            raise KeyError(key)
        return result

    # TODO(justinchuby): Add @functools.lru_cache(maxsize=None) if lookup time becomes
    # a problem.
    def get(self, opset: OpsetVersion) -> Optional[Callable]:
        """Finds the function registered for the version dispatched from ``opset``.

        Custom overrides take precedence over built-in functions at the same version.
        Returns ``None`` when no registered version matches.
        """
        version = _dispatch_opset_version(opset, self._functions)
        if version is None:
            return None

        return self._functions[version]

    def add(self, func: Callable, opset: OpsetVersion) -> None:
        """Adds a built-in symbolic function.

        Warns with ``errors.OnnxExporterWarning`` if a built-in function is already
        registered at ``opset``; the existing function is replaced either way.

        Args:
            func: The function to add.
            opset: The opset version of the function to add.
        """
        if self._functions.in_base(opset):
            warnings.warn(
                f"Symbolic function '{self._name}' already registered for opset {opset}. "
                f"Replacing the existing function with new function. This is unexpected. "
                f"Please report it on {_constants.PYTORCH_GITHUB_ISSUES_URL}.",
                errors.OnnxExporterWarning,
            )
        self._functions.set_base(opset, func)

    def add_custom(self, func: Callable, opset: OpsetVersion) -> None:
        """Adds a custom symbolic function that overrides any built-in one at ``opset``.

        Args:
            func: The symbolic function to register.
            opset: The corresponding opset version.
        """
        self._functions.override(opset, func)

    def remove_custom(self, opset: OpsetVersion) -> None:
        """Removes a custom symbolic function, restoring the built-in one if present.

        Warns (and does nothing else) if no custom function is registered at ``opset``.

        Args:
            opset: The opset version of the custom function to remove.
        """
        if not self._functions.overridden(opset):
            warnings.warn(
                f"No custom function registered for '{self._name}' opset {opset}"
            )
            return
        self._functions.remove_override(opset)

    def get_min_supported(self) -> OpsetVersion:
        """Returns the lowest opset version registered for the function.

        Both built-in functions and custom overrides are considered.

        Raises:
            ValueError: If no function is registered, e.g. after the only custom
                function of the group has been removed.
        """
        return min(self._functions)


class SymbolicRegistry:
    """Registry for symbolic functions.

    The registry maintains a mapping from qualified names to symbolic functions.
    It is used to register new symbolic functions and to dispatch calls to
    the appropriate function.
    """

    def __init__(self) -> None:
        self._registry: Dict[str, _SymbolicFunctionGroup] = {}

    def register(
        self, name: str, opset: OpsetVersion, func: Callable, custom: bool = False
    ) -> None:
        """Registers a symbolic function.

        Args:
            name: The qualified name of the function to register. In the form of 'domain::op'.
                E.g. 'aten::add'.
            opset: The opset version of the function to register.
            func: The symbolic function to register.
            custom: Whether the function is a custom function that overrides existing ones.

        Raises:
            ValueError: If the separator '::' is not in the name.
        """
        if "::" not in name:
            raise ValueError(
                f"The name must be in the form of 'domain::op', not '{name}'"
            )
        symbolic_functions = self._registry.get(name)
        if symbolic_functions is None:
            # Create the group only when missing; ``dict.setdefault`` would construct a
            # throwaway group on every registration of an existing name.
            symbolic_functions = _SymbolicFunctionGroup(name)
            self._registry[name] = symbolic_functions
        if custom:
            symbolic_functions.add_custom(func, opset)
        else:
            symbolic_functions.add(func, opset)

    def unregister(self, name: str, opset: OpsetVersion) -> None:
        """Unregisters a custom symbolic function.

        Only custom functions (registered with ``custom=True``) are removed; built-in
        functions are never removed. Unknown names are silently ignored, and a warning
        is emitted when ``name`` is known but has no custom function at ``opset``.

        Args:
            name: The qualified name of the function to unregister.
            opset: The opset version of the function to unregister.
        """
        if name not in self._registry:
            return
        self._registry[name].remove_custom(opset)

    def get_function_group(self, name: str) -> Optional[_SymbolicFunctionGroup]:
        """Returns the function group for the given name, or ``None`` if unknown."""
        return self._registry.get(name)

    def is_registered_op(self, name: str, version: int) -> bool:
        """Returns whether the given op is registered for the given opset version."""
        functions = self.get_function_group(name)
        if functions is None:
            return False
        return functions.get(version) is not None

    def all_functions(self) -> Set[str]:
        """Returns the set of all registered function names."""
        return set(self._registry)


def onnx_symbolic(
    name: str,
    opset: Union[OpsetVersion, Sequence[OpsetVersion]],
    decorate: Optional[Sequence[Callable]] = None,
    custom: bool = False,
) -> Callable:
    """Registers a symbolic function.

    Usage::

        @onnx_symbolic(
            "aten::symbolic_b",
            opset=10,
            decorate=[quantized_aten_handler(scale=1 / 128, zero_point=0)],
        )
        @symbolic_helper.parse_args("v", "v", "b")
        def symbolic_b(g: _C.Graph, x: _C.Value, y: _C.Value, arg1: bool) -> _C.Value: ...

    Args:
        name: The qualified name of the function in the form of 'domain::op'.
            E.g. 'aten::add'.
        opset: The opset version(s) of the function to register at.
        decorate: A sequence of decorators to apply, in order, to the registered
            function; the first decorator is the innermost.
        custom: Whether the function is a custom symbolic function.

    Returns:
        The decorator. It registers the decorated function for every opset and returns
        the original, undecorated function.

    Raises:
        ValueError: If the separator '::' is not in the name (raised when the decorator
            is applied).
    """
    # Normalize once, without rebinding the parameter from inside the closure.
    opsets: Sequence[OpsetVersion] = (
        (opset,) if isinstance(opset, OpsetVersion) else opset
    )

    def wrapper(func: Callable) -> Callable:
        decorated = func
        if decorate is not None:
            for decorate_func in decorate:
                decorated = decorate_func(decorated)

        # ``registry`` is the module-level registry; reading it needs no ``global``.
        for opset_version in opsets:
            registry.register(name, opset_version, decorated, custom=custom)

        # Return the original function because the decorators in "decorate" are only
        # specific to the instance being registered.
        return func

    return wrapper


def custom_onnx_symbolic(
    name: str,
    opset: Union[OpsetVersion, Sequence[OpsetVersion]],
    decorate: Optional[Sequence[Callable]] = None,
) -> Callable:
    """Registers a custom symbolic function.

    Equivalent to ``onnx_symbolic(name, opset, decorate, custom=True)``.

    Args:
        name: the qualified name of the function.
        opset: the opset version(s) of the function.
        decorate: a sequence of decorators to apply to the function.

    Returns:
        The decorator.

    Raises:
        ValueError: If the separator '::' is not in the name.
    """
    return onnx_symbolic(name, opset, decorate=decorate, custom=True)


# The registry for all symbolic functions.
registry = SymbolicRegistry()