xref: /aosp_15_r20/external/pytorch/torch/cuda/gds.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerimport os
2*da0073e9SAndroid Build Coastguard Workerimport sys
3*da0073e9SAndroid Build Coastguard Workerfrom typing import Callable, List, Optional
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerimport torch
6*da0073e9SAndroid Build Coastguard Workerfrom torch.types import Storage
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Worker
# Every definition in this module is underscore-prefixed (private), so the
# star-import surface is deliberately empty.
__all__: List[str] = list()
10*da0073e9SAndroid Build Coastguard Worker
11*da0073e9SAndroid Build Coastguard Worker
def _dummy_fn(name: str) -> Callable:
    """Build a placeholder for a ``torch._C`` binding that is not available.

    Args:
        name (str): Attribute name of the missing binding; it is embedded in
            the error message raised by the placeholder.

    Returns:
        Callable: A function that accepts any positional and keyword
        arguments and always raises ``RuntimeError``.
    """
    message = f"torch._C.{name} is not supported on this platform"

    def fn(*args, **kwargs):  # type: ignore[no-untyped-def]
        # Arguments are ignored; calling the placeholder is always an error.
        raise RuntimeError(message)

    return fn
17*da0073e9SAndroid Build Coastguard Worker
18*da0073e9SAndroid Build Coastguard Worker
if not hasattr(torch._C, "_gds_register_buffer"):
    # When torch._C lacks _gds_register_buffer (presumably a build without
    # GPUDirect Storage support), none of its companion bindings may exist.
    for _binding in (
        "_gds_deregister_buffer",
        "_gds_register_handle",
        "_gds_deregister_handle",
        "_gds_load_storage",
        "_gds_save_storage",
    ):
        assert not hasattr(torch._C, _binding)
    # Install placeholders that raise RuntimeError when called, so the Python
    # wrappers in this module can reference the bindings unconditionally.
    for _binding in (
        "_gds_register_buffer",
        "_gds_deregister_buffer",
        "_gds_register_handle",
        "_gds_deregister_handle",
        "_gds_load_storage",
        "_gds_save_storage",
    ):
        torch._C.__dict__[_binding] = _dummy_fn(_binding)
    del _binding
32*da0073e9SAndroid Build Coastguard Worker
33*da0073e9SAndroid Build Coastguard Worker
def _gds_register_buffer(s: Storage) -> None:
    """Register ``s`` as a GDS buffer via ``torch._C._gds_register_buffer``.

    Args:
        s (Storage): Storage to register.

    Returns:
        None. Whatever the native binding returns is discarded, and any
        exception it raises propagates unchanged.
    """
    native_register = torch._C._gds_register_buffer
    native_register(s)
41*da0073e9SAndroid Build Coastguard Worker
42*da0073e9SAndroid Build Coastguard Worker
def _gds_deregister_buffer(s: Storage) -> None:
    """Deregisters a buffer.

    Forwards ``s`` to ``torch._C._gds_deregister_buffer``. Presumably ``s``
    was previously registered with ``_gds_register_buffer``.

    Args:
        s (Storage): Buffer to deregister.
    """
    torch._C._gds_deregister_buffer(s)
50*da0073e9SAndroid Build Coastguard Worker
51*da0073e9SAndroid Build Coastguard Worker
class _GdsFile:
    r"""Wrapper around cuFile.

    cuFile is a file-like interface to the GPUDirect Storage (GDS) API.

    On construction the file is opened with ``os.open`` and its descriptor is
    registered via ``register_handle``. When the object is finalised
    (``__del__``), the handle is deregistered if it is still registered, and
    the descriptor is closed.

    Args:
        filename (str): Name of the file to open.
        flags (int): Flags to pass to ``os.open`` when opening the file. ``os.O_DIRECT`` will
            be added automatically.

    Raises:
        RuntimeError: If running on Windows, or if ``os`` does not expose
            ``O_DIRECT`` on this platform.
        OSError: If ``os.open`` fails.

    .. _CUDA GPUDirect Storage Documentation:
        https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufile-io-api
    """

    def __init__(self, filename: str, flags: int):
        # Give ``__del__`` well-defined attributes before anything can raise,
        # so finalising a half-constructed object never errors.
        self.fd: Optional[int] = None
        self.handle: Optional[int] = None
        # ``os.O_DIRECT`` is only defined on some platforms (e.g. not macOS);
        # reject those up front instead of failing with AttributeError below.
        if sys.platform == "win32" or not hasattr(os, "O_DIRECT"):
            raise RuntimeError("GdsFile is not supported on this platform.")
        self.filename = filename
        self.flags = flags
        fd = os.open(filename, flags | os.O_DIRECT)
        self.fd = fd
        try:
            self.register_handle()
        except BaseException:
            # The caller never receives this object, and the in-flight
            # traceback may keep it alive; don't leak the descriptor meanwhile.
            self.fd = None
            os.close(fd)
            raise

    def __del__(self) -> None:
        # ``getattr`` tolerates instances whose ``__init__`` never ran (or
        # failed) and therefore may lack these attributes.
        try:
            if getattr(self, "handle", None) is not None:
                self.deregister_handle()
        finally:
            # Close even if deregistration raised.
            fd = getattr(self, "fd", None)
            if fd is not None:
                # Clear before closing so a repeated finalisation cannot close
                # a descriptor number that has since been reused.
                self.fd = None
                os.close(fd)

    def register_handle(self) -> None:
        """Registers file descriptor to cuFile Driver.

        This is a wrapper around ``cuFileHandleRegister``. The value returned
        by ``torch._C._gds_register_handle`` is stored in ``self.handle``.

        Raises:
            AssertionError: If a handle is already registered.
        """
        assert (
            self.handle is None
        ), "Cannot register a handle that is already registered."
        self.handle = torch._C._gds_register_handle(self.fd)

    def deregister_handle(self) -> None:
        """Deregisters file descriptor from cuFile Driver.

        This is a wrapper around ``cuFileHandleDeregister``. On success
        ``self.handle`` is reset to ``None``.

        Raises:
            AssertionError: If no handle is registered.
        """
        assert (
            self.handle is not None
        ), "Cannot deregister a handle that is not registered."
        torch._C._gds_deregister_handle(self.handle)
        self.handle = None

    def load_storage(self, storage: Storage, offset: int = 0) -> None:
        """Loads data from the file into the storage.

        This is a wrapper around ``cuFileRead``. ``storage.nbytes()`` of data
        will be loaded from the file at ``offset`` into the storage.

        Args:
            storage (Storage): Storage to load data into.
            offset (int, optional): Offset into the file to start loading from. (Default: 0)

        Raises:
            AssertionError: If no handle is registered.
        """
        assert (
            self.handle is not None
        ), "Cannot load data from a file that is not registered."
        torch._C._gds_load_storage(self.handle, storage, offset)

    def save_storage(self, storage: Storage, offset: int = 0) -> None:
        """Saves data from the storage into the file.

        This is a wrapper around ``cuFileWrite``. All bytes of the storage
        will be written to the file at ``offset``.

        Args:
            storage (Storage): Storage to save data from.
            offset (int, optional): Offset into the file to start saving to. (Default: 0)

        Raises:
            AssertionError: If no handle is registered.
        """
        assert (
            self.handle is not None
        ), "Cannot save data to a file that is not registered."
        torch._C._gds_save_storage(self.handle, storage, offset)
130