# mypy: allow-untyped-defs
from typing import Callable, Optional

from .fake_impl import FakeImplHolder
from .utils import RegistrationHandle


__all__ = ["SimpleLibraryRegistry", "SimpleOperatorEntry", "singleton"]


class SimpleLibraryRegistry:
    """Registry for the "simple" torch.library APIs

    The "simple" torch.library APIs are a higher-level API on top of the
    raw PyTorch DispatchKey registration APIs that includes:
    - fake impl

    Registrations for these APIs do not go into the PyTorch dispatcher's
    table because they may not directly involve a DispatchKey. For example,
    the fake impl is a Python function that gets invoked by FakeTensor.
    Instead, we manage them here.

    SimpleLibraryRegistry is a mapping from a fully qualified operator name
    (including the overload) to SimpleOperatorEntry.
    """

    def __init__(self):
        # qualname -> SimpleOperatorEntry; entries are created lazily by find().
        self._data = {}

    def find(self, qualname: str) -> "SimpleOperatorEntry":
        """Return the entry for ``qualname``, creating an empty one if needed.

        This never returns None: the first lookup of a qualname inserts a
        fresh SimpleOperatorEntry, so merely calling find() adds the name to
        the registry.
        """
        if qualname not in self._data:
            self._data[qualname] = SimpleOperatorEntry(qualname)
        return self._data[qualname]


# Module-wide registry instance shared by the helpers in this module.
singleton: SimpleLibraryRegistry = SimpleLibraryRegistry()


class SimpleOperatorEntry:
    """This is 1:1 to an operator overload.

    The fields of SimpleOperatorEntry are Holders where kernels can be
    registered to.
    """

    def __init__(self, qualname: str):
        self.qualname: str = qualname
        # Holder for this operator's fake impl.
        self.fake_impl: FakeImplHolder = FakeImplHolder(qualname)
        # Holder for per-class `__torch_dispatch__` rules for this operator.
        # Quoted because the class is defined later in this module.
        self.torch_dispatch_rules: "GenericTorchDispatchRuleHolder" = (
            GenericTorchDispatchRuleHolder(qualname)
        )

    # For compatibility reasons. We can delete this soon.
    @property
    def abstract_impl(self):
        """Deprecated alias for ``fake_impl``."""
        return self.fake_impl


class GenericTorchDispatchRuleHolder:
    """Per-operator table of ``__torch_dispatch__`` rules.

    Maps a torch_dispatch class to the rule registered for the operator
    overload named by ``qualname``. At most one rule may be registered per
    class at a time.
    """

    def __init__(self, qualname: str):
        # torch_dispatch class -> registered rule.
        self._data = {}
        self.qualname = qualname

    def register(
        self, torch_dispatch_class: type, func: Callable
    ) -> RegistrationHandle:
        """Register ``func`` as the rule for ``torch_dispatch_class``.

        Args:
            torch_dispatch_class: the class the rule is keyed on.
            func: the rule to store; returned later by ``find``.

        Returns:
            A RegistrationHandle wrapping a callback that removes this
            registration. Running that callback more than once is a no-op
            after the first call.

        Raises:
            RuntimeError: if a rule is already registered for
                ``torch_dispatch_class`` on this operator.
        """
        # Check membership rather than the truthiness of the stored rule, so
        # an existing registration is detected even if the rule object itself
        # happens to evaluate as falsy.
        if torch_dispatch_class in self._data:
            raise RuntimeError(
                f"{torch_dispatch_class} already has a `__torch_dispatch__` rule registered for {self.qualname}"
            )
        self._data[torch_dispatch_class] = func

        # Tracks whether this particular registration is still live. register()
        # refuses duplicates and this closure is the only code path that
        # deletes entries, so while the flag is set the entry for
        # torch_dispatch_class is guaranteed to be the one installed here.
        registered = True

        def deregister():
            nonlocal registered
            # One-shot: a second call must neither raise KeyError nor remove a
            # rule that was registered for the same class after this one went
            # away.
            if not registered:
                return
            registered = False
            del self._data[torch_dispatch_class]

        return RegistrationHandle(deregister)

    def find(self, torch_dispatch_class):
        """Return the rule registered for ``torch_dispatch_class``, or None."""
        return self._data.get(torch_dispatch_class, None)


def find_torch_dispatch_rule(op, torch_dispatch_class: type) -> Optional[Callable]:
    """Look up the ``__torch_dispatch__`` rule for ``op`` on a given class.

    Args:
        op: the operator to look up. Only ``op.__qualname__`` is read; it is
            used as the registry key (presumably an operator overload whose
            qualname includes the overload name — verify against callers).
        torch_dispatch_class: the class whose registered rule is wanted.

    Returns:
        The registered rule, or None if no rule exists for that class.

    Note:
        The lookup goes through ``singleton.find``, which creates an empty
        SimpleOperatorEntry for ``op`` if one did not exist yet.
    """
    return singleton.find(op.__qualname__).torch_dispatch_rules.find(
        torch_dispatch_class
    )