xref: /aosp_15_r20/external/pytorch/torch/_library/simple_registry.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from typing import Callable, Optional
3
4from .fake_impl import FakeImplHolder
5from .utils import RegistrationHandle
6
7
# Public names re-exported by ``from ... import *``.
__all__ = [
    "SimpleLibraryRegistry",
    "SimpleOperatorEntry",
    "singleton",
]
9
10
class SimpleLibraryRegistry:
    """Registry backing the "simple" torch.library APIs.

    The "simple" torch.library APIs are a higher-level layer over the raw
    PyTorch DispatchKey registration APIs. They currently cover:
    - fake impls
    - ``__torch_dispatch__`` rules, keyed by torch_dispatch class

    These registrations are kept out of the PyTorch dispatcher's table
    because they need not correspond to any DispatchKey. For instance, a
    fake impl is a Python function that FakeTensor calls. This registry
    tracks them instead.

    Conceptually it maps a fully qualified operator name (overload
    included) to its SimpleOperatorEntry. Entries are created lazily the
    first time a name is looked up.
    """

    def __init__(self):
        # qualname -> SimpleOperatorEntry
        self._data = {}

    def find(self, qualname: str) -> "SimpleOperatorEntry":
        """Return the entry for ``qualname``, creating and storing one if absent."""
        try:
            return self._data[qualname]
        except KeyError:
            pass
        # First lookup for this operator: materialize an entry so that later
        # registrations and lookups share the same object.
        created = SimpleOperatorEntry(qualname)
        self._data[qualname] = created
        return created
34
35
# Module-level registry instance shared by this module; find_torch_dispatch_rule
# below performs its lookups against it.
singleton: SimpleLibraryRegistry = SimpleLibraryRegistry()
37
38
class SimpleOperatorEntry:
    """Registry record for a single operator overload (one entry per overload).

    Every field other than ``qualname`` is a Holder object into which
    kernels or rules for this operator can be registered.
    """

    def __init__(self, qualname: str):
        self.qualname: str = qualname
        # Holder for the fake impl (the Python function FakeTensor invokes).
        self.fake_impl: FakeImplHolder = FakeImplHolder(self.qualname)
        # Holder for ``__torch_dispatch__`` rules, keyed by torch_dispatch class.
        rule_holder = GenericTorchDispatchRuleHolder(self.qualname)
        self.torch_dispatch_rules: GenericTorchDispatchRuleHolder = rule_holder

    @property
    def abstract_impl(self):
        """Compatibility alias for ``fake_impl``; slated for removal."""
        return self.fake_impl
57
58
class GenericTorchDispatchRuleHolder:
    """Per-operator table of ``__torch_dispatch__`` rules.

    Maps a torch_dispatch class to the callable registered as this
    operator's rule for that class. At most one rule may be registered per
    class at a time.
    """

    def __init__(self, qualname: str):
        # torch_dispatch class -> registered rule callable
        self._data = {}
        # Name of the operator this holder belongs to (used in error messages).
        self.qualname = qualname

    def register(
        self, torch_dispatch_class: type, func: Callable
    ) -> RegistrationHandle:
        """Register ``func`` as this operator's rule for ``torch_dispatch_class``.

        Args:
            torch_dispatch_class: the class the rule applies to.
            func: the rule to register.

        Returns:
            A RegistrationHandle wrapping a callback that removes this
            registration again.

        Raises:
            RuntimeError: if a rule is already registered for
                ``torch_dispatch_class``.
        """
        # Test membership, not the truthiness of find(): a registered rule
        # that is falsy (custom __bool__/__len__, or None) must still count
        # as an existing registration instead of being silently overwritten.
        if torch_dispatch_class in self._data:
            raise RuntimeError(
                f"{torch_dispatch_class} already has a `__torch_dispatch__` rule registered for {self.qualname}"
            )
        self._data[torch_dispatch_class] = func

        def deregister():
            # Remove the slot only while it still holds the rule this handle
            # installed. Otherwise a stale handle (already used, slot since
            # re-registered with another rule) would delete someone else's
            # rule, and a repeated call would raise KeyError.
            if (
                torch_dispatch_class in self._data
                and self._data[torch_dispatch_class] is func
            ):
                del self._data[torch_dispatch_class]

        return RegistrationHandle(deregister)

    def find(self, torch_dispatch_class: type) -> Optional[Callable]:
        """Return the rule registered for ``torch_dispatch_class``, or None."""
        return self._data.get(torch_dispatch_class, None)
80
81
def find_torch_dispatch_rule(op, torch_dispatch_class: type) -> Optional[Callable]:
    """Return the ``__torch_dispatch__`` rule registered for ``op`` under ``torch_dispatch_class``.

    The operator's entry is fetched from the module-level ``singleton``
    registry using ``op.__qualname__``. NOTE(review): this is assumed to
    equal the registry key (the fully qualified overload name); confirm.
    That lookup creates an empty entry if none exists yet. Returns None when
    no rule is registered for the class.
    """
    entry = singleton.find(op.__qualname__)
    return entry.torch_dispatch_rules.find(torch_dispatch_class)
86