1import os 2import sys 3from typing import Callable, List, Optional 4 5import torch 6from torch.types import Storage 7 8 9__all__: List[str] = [] 10 11 12def _dummy_fn(name: str) -> Callable: 13 def fn(*args, **kwargs): # type: ignore[no-untyped-def] 14 raise RuntimeError(f"torch._C.{name} is not supported on this platform") 15 16 return fn 17 18 19if not hasattr(torch._C, "_gds_register_buffer"): 20 assert not hasattr(torch._C, "_gds_deregister_buffer") 21 assert not hasattr(torch._C, "_gds_register_handle") 22 assert not hasattr(torch._C, "_gds_deregister_handle") 23 assert not hasattr(torch._C, "_gds_load_storage") 24 assert not hasattr(torch._C, "_gds_save_storage") 25 # Define functions 26 torch._C.__dict__["_gds_register_buffer"] = _dummy_fn("_gds_register_buffer") 27 torch._C.__dict__["_gds_deregister_buffer"] = _dummy_fn("_gds_deregister_buffer") 28 torch._C.__dict__["_gds_register_handle"] = _dummy_fn("_gds_register_handle") 29 torch._C.__dict__["_gds_deregister_handle"] = _dummy_fn("_gds_deregister_handle") 30 torch._C.__dict__["_gds_load_storage"] = _dummy_fn("_gds_load_storage") 31 torch._C.__dict__["_gds_save_storage"] = _dummy_fn("_gds_save_storage") 32 33 34def _gds_register_buffer(s: Storage) -> None: 35 """Registers a buffer. 36 37 Args: 38 s (Storage): Buffer to register. 39 """ 40 torch._C._gds_register_buffer(s) 41 42 43def _gds_deregister_buffer(s: Storage) -> None: 44 """Registers a buffer. 45 46 Args: 47 s (Storage): Buffer to register. 48 """ 49 torch._C._gds_deregister_buffer(s) 50 51 52class _GdsFile: 53 r"""Wrapper around cuFile. 54 55 cuFile is a file-like interface to the GPUDirect Storage (GDS) API. 56 57 Args: 58 filename (str): Name of the file to open. 59 flags (int): Flags to pass to ``os.open`` when opening the file. ``os.O_DIRECT`` will 60 be added automatically. 61 62 .. _CUDA GPUDirect Storage Documentation: 63 https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufile-io-api 64 """ 65 66 def __init__(self, filename: str, flags: int): 67 if sys.platform == "win32": 68 raise RuntimeError("GdsFile is not supported on this platform.") 69 self.filename = filename 70 self.flags = flags 71 self.fd = os.open(filename, flags | os.O_DIRECT) 72 self.handle: Optional[int] = None 73 self.register_handle() 74 75 def __del__(self) -> None: 76 if self.handle is not None: 77 self.deregister_handle() 78 os.close(self.fd) 79 80 def register_handle(self) -> None: 81 """Registers file descriptor to cuFile Driver. 82 83 This is a wrapper around ``cuFileHandleRegister``. 84 """ 85 assert ( 86 self.handle is None 87 ), "Cannot register a handle that is already registered." 88 self.handle = torch._C._gds_register_handle(self.fd) 89 90 def deregister_handle(self) -> None: 91 """Deregisters file descriptor from cuFile Driver. 92 93 This is a wrapper around ``cuFileHandleDeregister``. 94 """ 95 assert ( 96 self.handle is not None 97 ), "Cannot deregister a handle that is not registered." 98 torch._C._gds_deregister_handle(self.handle) 99 self.handle = None 100 101 def load_storage(self, storage: Storage, offset: int = 0) -> None: 102 """Loads data from the file into the storage. 103 104 This is a wrapper around ``cuFileRead``. ``storage.nbytes()`` of data 105 will be loaded from the file at ``offset`` into the storage. 106 107 Args: 108 storage (Storage): Storage to load data into. 109 offset (int, optional): Offset into the file to start loading from. (Default: 0) 110 """ 111 assert ( 112 self.handle is not None 113 ), "Cannot load data from a file that is not registered." 114 torch._C._gds_load_storage(self.handle, storage, offset) 115 116 def save_storage(self, storage: Storage, offset: int = 0) -> None: 117 """Saves data from the storage into the file. 118 119 This is a wrapper around ``cuFileWrite``. All bytes of the storage 120 will be written to the file at ``offset``. 121 122 Args: 123 storage (Storage): Storage to save data from. 124 offset (int, optional): Offset into the file to start saving to. (Default: 0) 125 """ 126 assert ( 127 self.handle is not None 128 ), "Cannot save data to a file that is not registered." 129 torch._C._gds_save_storage(self.handle, storage, offset) 130