# mypy: allow-untyped-defs
import types

import torch._C


class _ClassNamespace(types.ModuleType):
    """Module-like view of a single custom-class namespace, e.g. ``torch.classes.my_ns``.

    Attribute access on an instance (``torch.classes.my_ns.MyClass``) is
    resolved on demand via ``torch._C._get_custom_class_python_wrapper``.

    Args:
        name (str): The namespace name; stored as ``self.name``.
    """

    def __init__(self, name):
        # Dotted name so ``__name__``/``repr`` read like a real submodule
        # ("torch.classes.my_ns"), not "torch.classesmy_ns".
        super().__init__("torch.classes." + name)
        self.name = name

    def __getattr__(self, attr):
        """Return the Python wrapper for class ``attr`` in this namespace.

        Raises:
            AttributeError: If ``attr`` is a dunder name (``__path__``,
                ``__wrapped__``, ...). Such names come from introspection, not
                from requests for a custom class. Reporting them the normal way
                keeps ``hasattr`` / ``getattr(obj, name, default)`` working.
            RuntimeError: If the wrapper lookup returns ``None``.
        """
        if attr.startswith("__") and attr.endswith("__"):
            raise AttributeError(
                f"torch.classes.{self.name} has no attribute {attr!r}"
            )
        proxy = torch._C._get_custom_class_python_wrapper(self.name, attr)
        if proxy is None:
            raise RuntimeError(f"Class {self.name}.{attr} not registered!")
        return proxy


class _Classes(types.ModuleType):
    """The ``torch.classes`` pseudo-module.

    Accessing an unknown attribute creates a ``_ClassNamespace`` for that name
    and caches it on the instance, so later accesses return the same object.
    """

    # Class-level attribute, so ``__file__`` lookups are answered without
    # reaching ``__getattr__`` below.
    __file__ = "_classes.py"

    def __init__(self) -> None:
        super().__init__("torch.classes")

    def __getattr__(self, name):
        # Dunder probes (e.g. ``__path__`` from tools that walk modules) must
        # not fabricate a fake namespace; report them as missing instead.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(f"module 'torch.classes' has no attribute {name!r}")
        namespace = _ClassNamespace(name)
        # Store on the instance so subsequent lookups hit ``__dict__`` directly
        # and never come back through ``__getattr__``.
        setattr(self, name, namespace)
        return namespace

    @property
    def loaded_libraries(self):
        """Delegates to ``torch.ops.loaded_libraries``."""
        return torch.ops.loaded_libraries

    def load_library(self, path):
        """
        Loads a shared library from the given path into the current process.

        The library being loaded may run global initialization code to register
        custom classes with the PyTorch JIT runtime. This allows dynamically
        loading custom classes. For this, you should compile your class
        and the static registration code into a shared library object, and then
        call ``torch.classes.load_library('path/to/libcustom.so')`` to load the
        shared object.

        After the library is loaded, it is added to the
        ``torch.classes.loaded_libraries`` attribute, a set that may be inspected
        for the paths of all libraries loaded using this function.

        Args:
            path (str): A path to a shared library to load.
        """
        torch.ops.load_library(path)


# The classes "namespace"
classes = _Classes()