# mypy: allow-untyped-defs
import inspect
from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union

import torch
from torch._streambase import _EventBase, _StreamBase


get_cuda_stream: Optional[Callable[[int], int]]
if torch.cuda._is_compiled():
    from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
else:
    get_cuda_stream = None
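
# Note: when PyTorch is compiled with CUDA, get_cuda_stream(device_index)
# returns the current raw cudaStream_t of that device as a plain int, which
# compiled wrapper code can pass along when launching kernels.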

_device_t = Union[torch.device, str, int, None]

# Device properties are recorded in the main process but used in the worker
# processes.
caching_worker_device_properties: Dict[str, Any] = {}
caching_worker_current_devices: Dict[str, int] = {}

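# For example, after CudaInterface.Worker.set_device(0) runs in the main
# process, caching_worker_current_devices == {"cuda": 0}.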

class DeviceInterfaceMeta(type):
    def __new__(metacls, *args, **kwargs):
        class_member = args[2]
        if "Event" in class_member:
            assert inspect.isclass(class_member["Event"]) and issubclass(
                class_member["Event"], _EventBase
            ), "DeviceInterface member Event should inherit from _EventBase"
        if "Stream" in class_member:
            assert inspect.isclass(class_member["Stream"]) and issubclass(
                class_member["Stream"], _StreamBase
            ), "DeviceInterface member Stream should inherit from _StreamBase"
        return super().__new__(metacls, *args, **kwargs)

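# The metaclass checks run at class-creation time; a sketch of a failing case
# (BadInterface is hypothetical, shown only to illustrate the assertion):
#
#     class BadInterface(DeviceInterface):
#         Event = int  # not an _EventBase subclass -> AssertionError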

class DeviceInterface(metaclass=DeviceInterfaceMeta):
    """
    This is a simple device runtime interface for Inductor. It enables custom
    backends to be integrated with Inductor with device-agnostic semantics.
    """

    class device:
        def __new__(cls, device: _device_t):
            raise NotImplementedError

    class Worker:
        """
        Worker API for querying device properties; it works in multiprocessing
        workers that cannot use the GPU APIs (due to fork() and
        initialization-time issues). Properties are recorded in the main
        process before the workers are forked.
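
        A sketch of the intended flow (CudaInterface is the concrete subclass
        defined below):

            # main process, before forking workers
            CudaInterface.Worker.set_device(0)
            CudaInterface.Worker.get_device_properties()  # warms the cache
            # forked worker: answered from the cache, without touching CUDA
            props = CudaInterface.Worker.get_device_properties()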
        """

        @staticmethod
        def set_device(device: int):
            raise NotImplementedError

        @staticmethod
        def current_device() -> int:
            raise NotImplementedError

        @staticmethod
        def get_device_properties(device: _device_t = None):
            raise NotImplementedError

    @staticmethod
    def current_device():
        raise NotImplementedError

    @staticmethod
    def set_device(device: _device_t):
        raise NotImplementedError

    @staticmethod
    def maybe_exchange_device(device: int) -> int:
        raise NotImplementedError

    @staticmethod
    def exchange_device(device: int) -> int:
        raise NotImplementedError

    @staticmethod
    def device_count():
        raise NotImplementedError

    @staticmethod
    def is_available() -> bool:
        raise NotImplementedError

    @staticmethod
    def stream(stream: torch.Stream):
        raise NotImplementedError

    @staticmethod
    def current_stream():
        raise NotImplementedError

    @staticmethod
    def set_stream(stream: torch.Stream):
        raise NotImplementedError

    @staticmethod
    def _set_stream_by_id(stream_id: int, device_index: int, device_type: int):
        raise NotImplementedError

    @staticmethod
    def get_raw_stream(device_idx: int) -> int:
        raise NotImplementedError

    @staticmethod
    def synchronize(device: _device_t = None):
        raise NotImplementedError

    @staticmethod
    def get_device_properties(device: _device_t = None):
        raise NotImplementedError

    @staticmethod
    def get_compute_capability(device: _device_t = None):
        raise NotImplementedError

    @staticmethod
    def is_bf16_supported(including_emulation: bool = False):
        raise NotImplementedError

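# A minimal sketch of how a custom backend might implement this interface
# (illustrative only: "my_device" and the my_backend module are hypothetical,
# and a real Event/Stream must derive from _EventBase/_StreamBase):
#
#     class MyDeviceInterface(DeviceInterface):
#         Event = my_backend.Event    # must subclass _EventBase
#         Stream = my_backend.Stream  # must subclass _StreamBase
#
#         @staticmethod
#         def is_available() -> bool:
#             return my_backend.is_available()
#
#     register_interface_for_device("my_device", MyDeviceInterface)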

class DeviceGuard:
    """
    This class provides a context manager for device switching. It is a
    stripped-down version of torch.{device_name}.device.

    The context manager changes the current device to the given device index
    on entering the context and restores the original device on exiting.
    The device is switched using the provided device interface.
    """

    def __init__(
        self, device_interface: Type[DeviceInterface], index: Optional[int]
    ) -> None:
        self.device_interface = device_interface
        self.idx = index
        self.prev_idx = -1

    def __enter__(self):
        if self.idx is not None:
            self.prev_idx = self.device_interface.exchange_device(self.idx)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any):
        if self.idx is not None:
            self.idx = self.device_interface.maybe_exchange_device(self.prev_idx)
        return False

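# Usage sketch (assumes a CUDA build with at least one visible device):
#
#     iface = get_interface_for_device("cuda")
#     with DeviceGuard(iface, 0):
#         ...  # device 0 is current here; the previous device is restored on exit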

class CudaInterface(DeviceInterface):
    device = torch.cuda.device

    # Register the Event and Stream classes with the backend interface;
    # both must be implemented and must inherit from _EventBase and _StreamBase.
    Event = torch.cuda.Event
    Stream = torch.cuda.Stream

    class Worker:
        @staticmethod
        def set_device(device: int):
            caching_worker_current_devices["cuda"] = device

        @staticmethod
        def current_device() -> int:
            if "cuda" in caching_worker_current_devices:
                return caching_worker_current_devices["cuda"]
            return torch.cuda.current_device()

        @staticmethod
        def get_device_properties(device: _device_t = None):
            if device is not None:
                if isinstance(device, str):
                    device = torch.device(device)
                    assert device.type == "cuda"
                if isinstance(device, torch.device):
                    device = device.index
            if device is None:
                device = CudaInterface.Worker.current_device()

            if "cuda" not in caching_worker_device_properties:
                device_prop = [
                    torch.cuda.get_device_properties(i)
                    for i in range(torch.cuda.device_count())
                ]
                caching_worker_device_properties["cuda"] = device_prop

            return caching_worker_device_properties["cuda"][device]

    current_device = staticmethod(torch.cuda.current_device)
    set_device = staticmethod(torch.cuda.set_device)
    device_count = staticmethod(torch.cuda.device_count)
    stream = staticmethod(torch.cuda.stream)  # type: ignore[assignment]
    current_stream = staticmethod(torch.cuda.current_stream)
    set_stream = staticmethod(torch.cuda.set_stream)  # type: ignore[assignment]
    _set_stream_by_id = staticmethod(torch.cuda._set_stream_by_id)  # type: ignore[assignment]
    synchronize = staticmethod(torch.cuda.synchronize)
    get_device_properties = staticmethod(torch.cuda.get_device_properties)  # type: ignore[assignment]
    get_raw_stream = staticmethod(get_cuda_stream)  # type: ignore[assignment, arg-type]
    exchange_device = staticmethod(torch.cuda._exchange_device)  # type: ignore[arg-type]
    maybe_exchange_device = staticmethod(torch.cuda._maybe_exchange_device)  # type: ignore[arg-type]
    is_bf16_supported = staticmethod(torch.cuda.is_bf16_supported)  # type: ignore[arg-type]

    # Can be mock-patched with the @patch decorator.
    @staticmethod
    def is_available() -> bool:
        return torch.cuda.is_available()

    @staticmethod
    def get_compute_capability(device: _device_t = None):
        if torch.version.hip is None:
            # e.g. compute capability 8.0 is reported as 80
            major, minor = torch.cuda.get_device_capability(device)
            return major * 10 + minor
        else:
            # On ROCm, return the base arch name, e.g. "gfx90a" from
            # "gfx90a:sramecc+:xnack-".
            return torch.cuda.get_device_properties(device).gcnArchName.split(":", 1)[0]


get_xpu_stream: Optional[Callable[[int], int]]
if torch.xpu._is_compiled():
    from torch._C import _xpu_getCurrentRawStream as get_xpu_stream
else:
    get_xpu_stream = None


class XpuInterface(DeviceInterface):
    device = torch.xpu.device
    Event = torch.xpu.Event
    Stream = torch.xpu.Stream

    class Worker:
        @staticmethod
        def set_device(device: int):
            caching_worker_current_devices["xpu"] = device

        @staticmethod
        def current_device() -> int:
            if "xpu" in caching_worker_current_devices:
                return caching_worker_current_devices["xpu"]
            return torch.xpu.current_device()

        @staticmethod
        def get_device_properties(device: _device_t = None):
            if device is not None:
                if isinstance(device, str):
                    device = torch.device(device)
                    assert device.type == "xpu"
                if isinstance(device, torch.device):
                    device = device.index
            if device is None:
                device = XpuInterface.Worker.current_device()

            if "xpu" not in caching_worker_device_properties:
                device_prop = [
                    torch.xpu.get_device_properties(i)
                    for i in range(torch.xpu.device_count())
                ]
                caching_worker_device_properties["xpu"] = device_prop

            return caching_worker_device_properties["xpu"][device]

    current_device = staticmethod(torch.xpu.current_device)
    set_device = staticmethod(torch.xpu.set_device)
    device_count = staticmethod(torch.xpu.device_count)
    stream = staticmethod(torch.xpu.stream)  # type: ignore[assignment]
    current_stream = staticmethod(torch.xpu.current_stream)
    set_stream = staticmethod(torch.xpu.set_stream)  # type: ignore[assignment]
    _set_stream_by_id = staticmethod(torch.xpu._set_stream_by_id)  # type: ignore[assignment]
    synchronize = staticmethod(torch.xpu.synchronize)
    get_device_properties = staticmethod(torch.xpu.get_device_properties)  # type: ignore[assignment]
    get_raw_stream = staticmethod(get_xpu_stream)  # type: ignore[assignment, arg-type]
    exchange_device = staticmethod(torch.xpu._exchange_device)  # type: ignore[arg-type]
    maybe_exchange_device = staticmethod(torch.xpu._maybe_exchange_device)  # type: ignore[arg-type]

    # Can be mock-patched with the @patch decorator.
    @staticmethod
    def is_available() -> bool:
        return torch.xpu.is_available()

    @staticmethod
    def get_compute_capability(device: _device_t = None):
        return torch.xpu.get_device_capability(device)

    @staticmethod
    def is_bf16_supported(including_emulation: bool = False) -> bool:
        return torch.xpu.is_bf16_supported()


device_interfaces: Dict[str, Type[DeviceInterface]] = {}
_device_initialized = False


def register_interface_for_device(
    device: Union[str, torch.device], device_interface: Type[DeviceInterface]
):
    if isinstance(device, torch.device):
        device = str(device)
    device_interfaces[device] = device_interface


def get_interface_for_device(device: Union[str, torch.device]) -> Type[DeviceInterface]:
    if isinstance(device, torch.device):
        device = str(device)
    if not _device_initialized:
        init_device_reg()
    if device in device_interfaces:
        return device_interfaces[device]
    raise NotImplementedError(f"No interface for device {device}")


def get_registered_device_interfaces() -> Iterable[Tuple[str, Type[DeviceInterface]]]:
    if not _device_initialized:
        init_device_reg()
    return device_interfaces.items()


def init_device_reg():
    global _device_initialized
    register_interface_for_device("cuda", CudaInterface)
    for i in range(torch.cuda.device_count()):
        register_interface_for_device(f"cuda:{i}", CudaInterface)

    register_interface_for_device("xpu", XpuInterface)
    for i in range(torch.xpu.device_count()):
        register_interface_for_device(f"xpu:{i}", XpuInterface)

    _device_initialized = True
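
# Example (sketch): look up the registered interface for a device string and
# query it through the device-agnostic API.
#
#     iface = get_interface_for_device("cuda")
#     if iface.is_available():
#         idx = iface.current_device()
#         cc = iface.get_compute_capability(idx)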