# mypy: allow-untyped-defs
r"""
This package introduces support for the XPU backend, specifically tailored for
Intel GPU optimization.

This package is lazily initialized, so you can always import it, and use
:func:`is_available()` to determine if your system supports XPU.
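
Example (illustrative sketch; the availability guard is safe to run on any
host, with or without an Intel GPU)::

    import torch

    if torch.xpu.is_available():
        device = torch.device("xpu")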
8"""
9import threading
10import traceback
11from functools import lru_cache
12from typing import Any, Callable, Dict, List, Optional, Tuple, Union
13
14import torch
15import torch._C
16from torch import device as _device
17from torch._utils import _dummy_type, _LazySeedTracker
18
19from ._utils import _get_device_index
20from .streams import Event, Stream
21
22
23_initialized = False
24_tls = threading.local()
25_initialization_lock = threading.Lock()
26_queued_calls: List[
27    Tuple[Callable[[], None], List[str]]
28] = []  # don't invoke these until initialization occurs
29_is_in_bad_fork = getattr(torch._C, "_xpu_isInBadFork", lambda: False)
30_device_t = Union[_device, str, int, None]
31_lazy_seed_tracker = _LazySeedTracker()
32default_generators: Tuple[torch._C.Generator] = ()  # type: ignore[assignment]
33
34
def _is_compiled() -> bool:
    r"""Return true if PyTorch was compiled with XPU support."""
    return torch._C._has_xpu


if _is_compiled():
    _XpuDeviceProperties = torch._C._XpuDeviceProperties
    _exchange_device = torch._C._xpu_exchangeDevice
    _maybe_exchange_device = torch._C._xpu_maybeExchangeDevice
else:
    # Define dummies if PyTorch was compiled without XPU
    _XpuDeviceProperties = _dummy_type("_XpuDeviceProperties")  # type: ignore[assignment, misc]

    def _exchange_device(device: int) -> int:
        raise NotImplementedError("PyTorch was compiled without XPU support")

    def _maybe_exchange_device(device: int) -> int:
        raise NotImplementedError("PyTorch was compiled without XPU support")


@lru_cache(maxsize=1)
def device_count() -> int:
    r"""Return the number of XPU devices available."""
    if not _is_compiled():
        return 0
    return torch._C._xpu_getDeviceCount()


def is_available() -> bool:
    r"""Return a bool indicating if XPU is currently available."""
    # This function never throws.
    return device_count() > 0


def is_bf16_supported():
    r"""Return a bool indicating if the current XPU device supports dtype bfloat16."""
    return True


def is_initialized():
    r"""Return whether PyTorch's XPU state has been initialized."""
    return _initialized and not _is_in_bad_fork()


def _lazy_call(callable, **kwargs):
    if is_initialized():
        callable()
    else:
        global _lazy_seed_tracker
        if kwargs.get("seed_all", False):
            _lazy_seed_tracker.queue_seed_all(callable, traceback.format_stack())
        elif kwargs.get("seed", False):
            _lazy_seed_tracker.queue_seed(callable, traceback.format_stack())
        else:
            # Don't store the actual traceback to avoid memory cycle
            _queued_calls.append((callable, traceback.format_stack()))


def init():
    r"""Initialize PyTorch's XPU state.

    This is a Python API for lazy initialization that avoids initializing
    XPU until the first time it is accessed. Does nothing if the XPU state is
    already initialized.
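
    Example (illustrative; forcing eager initialization is occasionally useful,
    e.g. before timing the first kernel launch)::

        >>> # xdoctest: +SKIP
        >>> torch.xpu.init()
        >>> torch.xpu.is_initialized()
        True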
98    """
99    _lazy_init()
100
101
102def _lazy_init():
103    global _initialized, _queued_calls
104    if is_initialized() or hasattr(_tls, "is_initializing"):
105        return
106    with _initialization_lock:
107        # This test was was protected via GIL. Double-check whether XPU has
108        # already been initialized.
109        if is_initialized():
110            return
111        # Stop promptly upon encountering a bad fork error.
112        if _is_in_bad_fork():
113            raise RuntimeError(
114                "Cannot re-initialize XPU in forked subprocess. To use XPU with "
115                "multiprocessing, you must use the 'spawn' start method"
116            )
117        if not _is_compiled():
118            raise AssertionError("Torch not compiled with XPU enabled")
119        # This function inits XPU backend and detects bad fork processing.
120        torch._C._xpu_init()
121        # Some of the queued calls may reentrantly call _lazy_init(); We need to
122        # just return without initializing in that case.
123        _tls.is_initializing = True

        for calls in _lazy_seed_tracker.get_calls():
            if calls:
                _queued_calls.append(calls)

        try:
            for queued_call, orig_traceback in _queued_calls:
                try:
                    queued_call()
                except Exception as e:
                    msg = (
                        f"XPU call failed lazily at initialization with error: {str(e)}\n\n"
                        f"XPU call was originally invoked at:\n\n{''.join(orig_traceback)}"
                    )
                    raise Exception(msg) from e  # noqa: TRY002
        finally:
            delattr(_tls, "is_initializing")
        _initialized = True


class _DeviceGuard:
    def __init__(self, index: int):
        self.idx = index
        self.prev_idx = -1

    def __enter__(self):
        self.prev_idx = torch.xpu._exchange_device(self.idx)

    def __exit__(self, type: Any, value: Any, traceback: Any):
        self.idx = torch.xpu._maybe_exchange_device(self.prev_idx)
        return False


class device:
    r"""Context-manager that changes the selected device.

    Args:
        device (torch.device or int or str): device index to select. It's a no-op if
            this argument is a negative integer or ``None``.
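
    Example (illustrative sketch; assumes at least one XPU device is visible)::

        >>> # xdoctest: +SKIP
        >>> with torch.xpu.device(0):
        ...     x = torch.empty(4, device="xpu")  # allocated on device 0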
163    """
164
165    def __init__(self, device: Any):
166        self.idx = _get_device_index(device, optional=True)
167        self.prev_idx = -1
168
169    def __enter__(self):
170        self.prev_idx = torch.xpu._exchange_device(self.idx)
171
172    def __exit__(self, type: Any, value: Any, traceback: Any):
173        self.idx = torch.xpu._maybe_exchange_device(self.prev_idx)
174        return False
175
176
class device_of(device):
    r"""Context-manager that changes the current device to that of a given object.

    You can use both tensors and storages as arguments. If a given object is
    not allocated on an XPU, this is a no-op.

    Args:
        obj (Tensor or Storage): object allocated on the selected device.
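
    Example (illustrative sketch; assumes ``x`` lives on an XPU device)::

        >>> # xdoctest: +SKIP
        >>> x = torch.empty(4, device="xpu")
        >>> with torch.xpu.device_of(x):
        ...     idx = torch.xpu.current_device()  # == x.get_device()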
185    """
186
187    def __init__(self, obj):
188        idx = obj.get_device() if obj.is_xpu else -1
189        super().__init__(idx)
190
191
def set_device(device: _device_t) -> None:
    r"""Set the current device.

    Args:
        device (torch.device or int or str): selected device. This function is a
            no-op if this argument is negative.
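
    Example (illustrative; in most code the :class:`~torch.xpu.device` context
    manager is preferable to a process-wide switch)::

        >>> # xdoctest: +SKIP
        >>> torch.xpu.set_device(0)
        >>> torch.xpu.current_device()
        0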
198    """
199    _lazy_init()
200    device = _get_device_index(device)
201    if device >= 0:
202        torch._C._xpu_setDevice(device)
203
204
205def get_device_name(device: Optional[_device_t] = None) -> str:
206    r"""Get the name of a device.
207
208    Args:
209        device (torch.device or int or str, optional): device for which to
210            return the name. This function is a no-op if this argument is a
211            negative integer. It uses the current device, given by :func:`~torch.xpu.current_device`,
212            if :attr:`device` is ``None`` (default).
213
214    Returns:
215        str: the name of the device
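
    Example (illustrative; the exact string depends on the installed hardware)::

        >>> # xdoctest: +SKIP
        >>> torch.xpu.get_device_name(0)  # e.g. an Intel GPU model string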
216    """
217    return get_device_properties(device).name
218
219
220@lru_cache(None)
221def get_device_capability(device: Optional[_device_t] = None) -> Dict[str, Any]:
222    r"""Get the xpu capability of a device.
223
224    Args:
225        device (torch.device or int or str, optional): device for which to
226            return the device capability. This function is a no-op if this
227            argument is a negative integer. It uses the current device, given by
228            :func:`~torch.xpu.current_device`, if :attr:`device` is ``None``
229            (default).
230
231    Returns:
232        Dict[str, Any]: the xpu capability dictionary of the device
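
    Example (illustrative; the keys mirror the non-dunder attributes of
    ``_XpuDeviceProperties``)::

        >>> # xdoctest: +SKIP
        >>> caps = torch.xpu.get_device_capability(0)
        >>> caps["name"] == torch.xpu.get_device_name(0)
        True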
233    """
234    props = get_device_properties(device)
235    return {
236        prop: getattr(props, prop) for prop in dir(props) if not prop.startswith("__")
237    }
238
239
240def get_device_properties(device: Optional[_device_t] = None) -> _XpuDeviceProperties:
241    r"""Get the properties of a device.
242
243    Args:
244        device (torch.device or int or str): device for which to return the
245            properties of the device.
246
247    Returns:
248        _XpuDeviceProperties: the properties of the device
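
    Example (illustrative; attribute names are those exposed by
    ``_XpuDeviceProperties``, e.g. ``name`` as used by :func:`get_device_name`)::

        >>> # xdoctest: +SKIP
        >>> props = torch.xpu.get_device_properties(0)
        >>> props.name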
249    """
250    _lazy_init()
251    device = _get_device_index(device, optional=True)
252    if device < 0 or device >= device_count():
253        raise AssertionError("Invalid device index")
254    return _get_device_properties(device)  # type: ignore[name-defined]  # noqa: F821
255
256
257def current_device() -> int:
258    r"""Return the index of a currently selected device."""
259    _lazy_init()
260    return torch._C._xpu_getDevice()
261
262
263def _get_device(device: Union[int, str, torch.device]) -> torch.device:
264    r"""Return the torch.device type object from the passed in device.
265
266    Args:
267        device (torch.device or int or str): selected device.
268    """
269    if isinstance(device, str):
270        device = torch.device(device)
271    elif isinstance(device, int):
272        device = torch.device("xpu", device)
273    return device
274
275
276class StreamContext:
277    r"""Context-manager that selects a given stream.
278
279    All XPU kernels queued within its context will be enqueued on a selected
280    stream.
281
282    Args:
283        Stream (Stream): selected stream. This manager is a no-op if it's
284            ``None``.
285    .. note:: Streams are per-device.
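
    Example (illustrative; user code normally goes through the
    :func:`~torch.xpu.stream` wrapper below)::

        >>> # xdoctest: +SKIP
        >>> s = torch.xpu.Stream()
        >>> with torch.xpu.StreamContext(s):
        ...     pass  # kernels launched here are enqueued on ``s``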
286    """
287    cur_stream: Optional["torch.xpu.Stream"]
288
289    def __init__(self, stream: Optional["torch.xpu.Stream"]):
290        self.stream = stream
291        self.idx = _get_device_index(None, True)
292        if self.idx is None:
293            self.idx = -1
294
295    def __enter__(self):
296        cur_stream = self.stream
297        if cur_stream is None or self.idx == -1:
298            return
299        self.src_prev_stream = torch.xpu.current_stream(None)
300
301        # If the stream is not on the current device, then set the current stream on the device
302        if self.src_prev_stream.device != cur_stream.device:
303            with device(cur_stream.device):
304                self.dst_prev_stream = torch.xpu.current_stream(cur_stream.device)
305        torch.xpu.set_stream(cur_stream)
306
307    def __exit__(self, type: Any, value: Any, traceback: Any):
308        cur_stream = self.stream
309        if cur_stream is None or self.idx == -1:
310            return
311
312        # Reset the stream on the original device and destination device
313        if self.src_prev_stream.device != cur_stream.device:
314            torch.xpu.set_stream(self.dst_prev_stream)
315        torch.xpu.set_stream(self.src_prev_stream)
316
317
318def stream(stream: Optional["torch.xpu.Stream"]) -> StreamContext:
319    r"""Wrap around the Context-manager StreamContext that selects a given stream.
320
321    Arguments:
322        stream (Stream): selected stream. This manager is a no-op if it's ``None``.
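
    Example (illustrative sketch; assumes an available XPU device)::

        >>> # xdoctest: +SKIP
        >>> x = torch.ones(2, device="xpu")
        >>> s = torch.xpu.Stream()
        >>> with torch.xpu.stream(s):
        ...     y = x + 1  # enqueued on ``s``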
323    """
324    return StreamContext(stream)
325
326
327def _set_stream_by_id(stream_id, device_index, device_type):
328    r"""set stream specified by the stream id, device index and device type
329
330    Args: stream_id (int): not visible to the user, used to assigned to the specific stream.
331          device_index (int): selected device index.
332          device_type (int): selected device type.
333    """
334    torch._C._xpu_setStream(
335        stream_id=stream_id,
336        device_index=device_index,
337        device_type=device_type,
338    )
339
340
341def set_stream(stream: Stream):
342    r"""Set the current stream.This is a wrapper API to set the stream.
343        Usage of this function is discouraged in favor of the ``stream``
344        context manager.
345
346    Args:
347        stream (Stream): selected stream. This function is a no-op
348            if this argument is ``None``.
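
    Example (illustrative; the ``stream`` context manager above is preferred)::

        >>> # xdoctest: +SKIP
        >>> s = torch.xpu.Stream()
        >>> torch.xpu.set_stream(s)  # subsequent kernels target ``s``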
349    """
350    if stream is None:
351        return
352    _lazy_init()
353    _set_stream_by_id(
354        stream_id=stream.stream_id,
355        device_index=stream.device_index,
356        device_type=stream.device_type,
357    )
358
359
360def current_stream(device: Optional[_device_t] = None) -> Stream:
361    r"""Return the currently selected :class:`Stream` for a given device.
362
363    Args:
364        device (torch.device or int, optional): selected device. Returns
365            the currently selected :class:`Stream` for the current device, given
366            by :func:`~torch.xpu.current_device`, if :attr:`device` is ``None``
367            (default).
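
    Example (illustrative; before any explicit switch this is the device's
    default stream)::

        >>> # xdoctest: +SKIP
        >>> torch.xpu.current_stream()  # Stream for the current XPU device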
368    """
369    _lazy_init()
370    streamdata = torch._C._xpu_getCurrentStream(
371        _get_device_index(device, optional=True)
372    )
373    return Stream(
374        stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2]
375    )
376
377
378def synchronize(device: _device_t = None) -> None:
379    r"""Wait for all kernels in all streams on a XPU device to complete.
380
381    Args:
382        device (torch.device or int, optional): device for which to synchronize.
383            It uses the current device, given by :func:`~torch.xpu.current_device`,
384            if :attr:`device` is ``None`` (default).
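
    Example (illustrative; useful when timing device work from the host, since
    kernel launches are asynchronous)::

        >>> # xdoctest: +SKIP
        >>> x = torch.randn(128, 128, device="xpu")
        >>> y = x @ x  # queued on the current stream
        >>> torch.xpu.synchronize()  # block until the device is idle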
385    """
386    _lazy_init()
387    device = _get_device_index(device, optional=True)
388    return torch._C._xpu_synchronize(device)
389
390
391def _get_generator(device: torch.device) -> torch._C.Generator:
392    r"""Return the XPU Generator object for the given device.
393
394    Args:
395        device (torch.device): selected device.
396    """
397    idx = device.index
398    if idx is None:
399        idx = current_device()
400    return torch.xpu.default_generators[idx]
401
402
def _set_rng_state_offset(
    offset: int, device: Union[int, str, torch.device] = "xpu"
) -> None:
    r"""Set the random number generator state offset of the specified XPU.

    Args:
        offset (int): The desired offset
        device (torch.device or int, optional): The device to set the RNG state.
            Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).
    """
    final_device = _get_device(device)

    def cb():
        default_generator = _get_generator(final_device)
        default_generator.set_offset(offset)

    _lazy_call(cb)


def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int:
    r"""Return the random number generator state offset of the specified XPU.

    Args:
        device (torch.device or int, optional): The device to return the RNG state offset of.
            Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device).

    .. warning::
        This function eagerly initializes XPU.
    """
    _lazy_init()
    final_device = _get_device(device)
    default_generator = _get_generator(final_device)
    return default_generator.get_offset()


# import here to avoid circular import
from .memory import (
    empty_cache,
    max_memory_allocated,
    max_memory_reserved,
    memory_allocated,
    memory_reserved,
    memory_stats,
    memory_stats_as_nested_dict,
    reset_accumulated_memory_stats,
    reset_peak_memory_stats,
)
from .random import (
    get_rng_state,
    get_rng_state_all,
    initial_seed,
    manual_seed,
    manual_seed_all,
    seed,
    seed_all,
    set_rng_state,
    set_rng_state_all,
)


__all__ = [
    "Event",
    "Stream",
    "StreamContext",
    "current_device",
    "current_stream",
    "default_generators",
    "device",
    "device_of",
    "device_count",
    "empty_cache",
    "get_device_capability",
    "get_device_name",
    "get_device_properties",
    "get_rng_state",
    "get_rng_state_all",
    "get_stream",
    "init",
    "initial_seed",
    "is_available",
    "is_bf16_supported",
    "is_initialized",
    "manual_seed",
    "manual_seed_all",
    "max_memory_allocated",
    "max_memory_reserved",
    "memory_allocated",
    "memory_reserved",
    "memory_stats",
    "memory_stats_as_nested_dict",
    "reset_accumulated_memory_stats",
    "reset_peak_memory_stats",
    "seed",
    "seed_all",
    "set_device",
    "set_rng_state",
    "set_rng_state_all",
    "set_stream",
    "stream",
    "streams",
    "synchronize",
]
