xref: /aosp_15_r20/external/pytorch/torch/_classes.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import types
3
4import torch._C
5
6
class _ClassNamespace(types.ModuleType):
    """Module-like proxy for one custom-class namespace, e.g. ``torch.classes.my_ns``.

    Attribute access on this object resolves a custom class registered under
    the namespace ``self.name`` via
    ``torch._C._get_custom_class_python_wrapper``.
    """

    def __init__(self, name: str) -> None:
        # Dotted module name, e.g. "torch.classes.my_ns", matching the
        # attribute path users write to reach this namespace.
        super().__init__("torch.classes." + name)
        self.name = name

    def __getattr__(self, attr: str):
        """Return the Python wrapper for the custom class ``<name>.<attr>``.

        Only called for attributes not found by normal lookup.

        Raises:
            AttributeError: if ``attr`` is a dunder name. Such lookups are
                probes from Python machinery (``hasattr``, ``inspect``,
                import hooks), not user code naming a class, and those
                probes rely on ``AttributeError`` to mean "absent".
            RuntimeError: if the wrapper lookup reports no such class
                (returns ``None``).
        """
        if attr.startswith("__") and attr.endswith("__"):
            raise AttributeError(
                f"module {self.__name__!r} has no attribute {attr!r}"
            )
        proxy = torch._C._get_custom_class_python_wrapper(self.name, attr)
        if proxy is None:
            raise RuntimeError(f"Class {self.name}.{attr} not registered!")
        return proxy
17
18
class _Classes(types.ModuleType):
    """Module object named ``torch.classes``.

    Any attribute not otherwise defined is treated as a custom-class
    namespace: ``torch.classes.my_ns`` returns a ``_ClassNamespace`` for
    ``my_ns``, which is cached on this object after the first access.
    """

    __file__ = "_classes.py"

    def __init__(self) -> None:
        super().__init__("torch.classes")

    def __getattr__(self, name: str) -> "_ClassNamespace":
        """Create, cache and return the namespace proxy for ``name``.

        Dunder names raise ``AttributeError`` instead. They are probes from
        Python machinery (e.g. ``__path__`` from import hooks, ``__wrapped__``
        from ``inspect.unwrap``). Answering them with a namespace object would
        mislead the caller and permanently cache a bogus attribute.
        """
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(
                f"module {self.__name__!r} has no attribute {name!r}"
            )
        namespace = _ClassNamespace(name)
        # Store on the instance so later lookups succeed normally and no
        # longer reach __getattr__.
        setattr(self, name, namespace)
        return namespace

    @property
    def loaded_libraries(self):
        """Paths of libraries loaded via ``load_library``.

        This is the same object as ``torch.ops.loaded_libraries``.
        """
        return torch.ops.loaded_libraries

    def load_library(self, path: str) -> None:
        """
        Loads a shared library from the given path into the current process.

        The library being loaded may run global initialization code to register
        custom classes with the PyTorch JIT runtime. This allows dynamically
        loading custom classes. For this, you should compile your class
        and the static registration code into a shared library object, and then
        call ``torch.classes.load_library('path/to/libcustom.so')`` to load the
        shared object.

        After the library is loaded, it is added to the
        ``torch.classes.loaded_libraries`` attribute, a set that may be inspected
        for the paths of all libraries loaded using this function. This
        delegates to ``torch.ops.load_library``, so the bookkeeping is shared
        with ``torch.ops``.

        Args:
            path (str): A path to a shared library to load.
        """
        torch.ops.load_library(path)
54
# Singleton module object named ``torch.classes``; attribute access on it
# yields per-namespace proxies for registered custom classes.
classes: _Classes = _Classes()
57