xref: /aosp_15_r20/external/pytorch/torch/cuda/gds.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1import os
2import sys
3from typing import Callable, List, Optional
4
5import torch
6from torch.types import Storage
7
8
# The helpers defined in this module are all underscore-prefixed (private),
# so a star-import from it deliberately exports nothing.
__all__: List[str] = list()
10
11
def _dummy_fn(name: str) -> Callable:
    """Build a stand-in for a ``torch._C`` binding that this build does not provide.

    Args:
        name (str): Attribute name of the missing binding on ``torch._C``;
            only used to build the error message.

    Returns:
        Callable: A function accepting any positional/keyword arguments that
        always raises ``RuntimeError`` naming ``torch._C.<name>``.
    """
    message = f"torch._C.{name} is not supported on this platform"

    def _unsupported(*args, **kwargs):  # type: ignore[no-untyped-def]
        raise RuntimeError(message)

    return _unsupported
17
18
if not hasattr(torch._C, "_gds_register_buffer"):
    # The GDS bindings are treated as all-or-nothing: when the primary one is
    # missing, none of the others may be present either (presumably they are
    # compiled in or out together -- not verifiable from this file).
    assert all(
        not hasattr(torch._C, binding)
        for binding in (
            "_gds_deregister_buffer",
            "_gds_register_handle",
            "_gds_deregister_handle",
            "_gds_load_storage",
            "_gds_save_storage",
        )
    )
    # Install stand-ins so that calling any GDS wrapper raises a descriptive
    # RuntimeError (see ``_dummy_fn``) instead of an AttributeError.
    torch._C.__dict__.update(
        {
            binding: _dummy_fn(binding)
            for binding in (
                "_gds_register_buffer",
                "_gds_deregister_buffer",
                "_gds_register_handle",
                "_gds_deregister_handle",
                "_gds_load_storage",
                "_gds_save_storage",
            )
        }
    )
32
33
def _gds_register_buffer(s: Storage) -> None:
    """Register storage ``s`` as a GDS buffer.

    Thin forwarding wrapper over ``torch._C._gds_register_buffer``; whatever
    the native call returns is discarded. If this torch build lacks the GDS
    bindings, the stub installed at import time makes this raise
    ``RuntimeError``.

    Args:
        s (Storage): Buffer to register.
    """
    native_register = torch._C._gds_register_buffer
    native_register(s)
41
42
def _gds_deregister_buffer(s: Storage) -> None:
    """Deregisters a buffer.

    Counterpart of ``_gds_register_buffer``; forwards to
    ``torch._C._gds_deregister_buffer``. ``s`` is presumably expected to have
    been registered earlier via ``_gds_register_buffer`` (not checked here).
    If this torch build lacks the GDS bindings, the stub installed at import
    time makes this raise ``RuntimeError``.

    Args:
        s (Storage): Buffer to deregister.
    """
    torch._C._gds_deregister_buffer(s)
50
51
class _GdsFile:
    r"""Wrapper around cuFile.

    cuFile is a file-like interface to the GPUDirect Storage (GDS) API.

    Construction opens ``filename`` with ``os.O_DIRECT`` added to ``flags`` and
    immediately registers the resulting descriptor via :meth:`register_handle`.
    ``__del__`` deregisters the handle (if still registered) and closes the
    descriptor.

    Args:
        filename (str): Name of the file to open.
        flags (int): Flags to pass to ``os.open`` when opening the file. ``os.O_DIRECT`` will
            be added automatically.

    Raises:
        RuntimeError: If running on Windows or on a platform whose ``os`` module
            does not define ``O_DIRECT``. Errors from ``os.open`` and from the
            native registration call propagate unchanged.

    .. _CUDA GPUDirect Storage Documentation:
        https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufile-io-api
    """

    def __init__(self, filename: str, flags: int):
        # ``os.O_DIRECT`` is only defined when the C library provides it (it is
        # absent on Windows, for instance). Report that as an unsupported
        # platform instead of failing below with an opaque AttributeError.
        if sys.platform == "win32" or not hasattr(os, "O_DIRECT"):
            raise RuntimeError("GdsFile is not supported on this platform.")
        self.filename = filename
        self.flags = flags
        # Raw OS file descriptor; set to -1 once ``__del__`` has closed it.
        self.fd = os.open(filename, flags | os.O_DIRECT)
        # Handle returned by ``torch._C._gds_register_handle``; None while
        # the descriptor is not registered.
        self.handle: Optional[int] = None
        self.register_handle()

    def __del__(self) -> None:
        # ``__init__`` may have raised before ``handle``/``fd`` were assigned
        # (unsupported platform, failed ``os.open``), so read them defensively.
        try:
            if getattr(self, "handle", None) is not None:
                self.deregister_handle()
        finally:
            # Close the descriptor even if deregistration raised, and mark it
            # closed first so a repeated call cannot close a reused fd number.
            fd = getattr(self, "fd", -1)
            if fd >= 0:
                self.fd = -1
                os.close(fd)

    def register_handle(self) -> None:
        """Registers file descriptor to cuFile Driver.

        This is a wrapper around ``cuFileHandleRegister``. Stores the returned
        handle in ``self.handle``.

        Raises:
            AssertionError: If a handle is already registered.
        """
        assert (
            self.handle is None
        ), "Cannot register a handle that is already registered."
        self.handle = torch._C._gds_register_handle(self.fd)

    def deregister_handle(self) -> None:
        """Deregisters file descriptor from cuFile Driver.

        This is a wrapper around ``cuFileHandleDeregister``. Resets
        ``self.handle`` to None afterwards; the file descriptor stays open.

        Raises:
            AssertionError: If no handle is currently registered.
        """
        assert (
            self.handle is not None
        ), "Cannot deregister a handle that is not registered."
        torch._C._gds_deregister_handle(self.handle)
        self.handle = None

    def load_storage(self, storage: Storage, offset: int = 0) -> None:
        """Loads data from the file into the storage.

        This is a wrapper around ``cuFileRead``. ``storage.nbytes()`` of data
        will be loaded from the file at ``offset`` into the storage.

        Args:
            storage (Storage): Storage to load data into.
            offset (int, optional): Offset into the file to start loading from. (Default: 0)

        Raises:
            AssertionError: If the file descriptor is not registered.
        """
        assert (
            self.handle is not None
        ), "Cannot load data from a file that is not registered."
        torch._C._gds_load_storage(self.handle, storage, offset)

    def save_storage(self, storage: Storage, offset: int = 0) -> None:
        """Saves data from the storage into the file.

        This is a wrapper around ``cuFileWrite``. All bytes of the storage
        will be written to the file at ``offset``.

        Args:
            storage (Storage): Storage to save data from.
            offset (int, optional): Offset into the file to start saving to. (Default: 0)

        Raises:
            AssertionError: If the file descriptor is not registered.
        """
        assert (
            self.handle is not None
        ), "Cannot save data to a file that is not registered."
        torch._C._gds_save_storage(self.handle, storage, offset)
130