1# mypy: allow-untyped-defs 2r""" 3This package introduces support for the XPU backend, specifically tailored for 4Intel GPU optimization. 5 6This package is lazily initialized, so you can always import it, and use 7:func:`is_available()` to determine if your system supports XPU. 8""" 9import threading 10import traceback 11from functools import lru_cache 12from typing import Any, Callable, Dict, List, Optional, Tuple, Union 13 14import torch 15import torch._C 16from torch import device as _device 17from torch._utils import _dummy_type, _LazySeedTracker 18 19from ._utils import _get_device_index 20from .streams import Event, Stream 21 22 23_initialized = False 24_tls = threading.local() 25_initialization_lock = threading.Lock() 26_queued_calls: List[ 27 Tuple[Callable[[], None], List[str]] 28] = [] # don't invoke these until initialization occurs 29_is_in_bad_fork = getattr(torch._C, "_xpu_isInBadFork", lambda: False) 30_device_t = Union[_device, str, int, None] 31_lazy_seed_tracker = _LazySeedTracker() 32default_generators: Tuple[torch._C.Generator] = () # type: ignore[assignment] 33 34 35def _is_compiled() -> bool: 36 r"""Return true if compile with XPU support.""" 37 return torch._C._has_xpu 38 39 40if _is_compiled(): 41 _XpuDeviceProperties = torch._C._XpuDeviceProperties 42 _exchange_device = torch._C._xpu_exchangeDevice 43 _maybe_exchange_device = torch._C._xpu_maybeExchangeDevice 44else: 45 # Define dummy if PyTorch was compiled without XPU 46 _XpuDeviceProperties = _dummy_type("_XpuDeviceProperties") # type: ignore[assignment, misc] 47 48 def _exchange_device(device: int) -> int: 49 raise NotImplementedError("PyTorch was compiled without XPU support") 50 51 def _maybe_exchange_device(device: int) -> int: 52 raise NotImplementedError("PyTorch was compiled without XPU support") 53 54 55@lru_cache(maxsize=1) 56def device_count() -> int: 57 r"""Return the number of XPU device available.""" 58 if not _is_compiled(): 59 return 0 60 return torch._C._xpu_getDeviceCount() 61 62 63def is_available() -> bool: 64 r"""Return a bool indicating if XPU is currently available.""" 65 # This function nerver throws. 66 return device_count() > 0 67 68 69def is_bf16_supported(): 70 r"""Return a bool indicating if the current XPU device supports dtype bfloat16.""" 71 return True 72 73 74def is_initialized(): 75 r"""Return whether PyTorch's XPU state has been initialized.""" 76 return _initialized and not _is_in_bad_fork() 77 78 79def _lazy_call(callable, **kwargs): 80 if is_initialized(): 81 callable() 82 else: 83 global _lazy_seed_tracker 84 if kwargs.get("seed_all", False): 85 _lazy_seed_tracker.queue_seed_all(callable, traceback.format_stack()) 86 elif kwargs.get("seed", False): 87 _lazy_seed_tracker.queue_seed(callable, traceback.format_stack()) 88 else: 89 # Don't store the actual traceback to avoid memory cycle 90 _queued_calls.append((callable, traceback.format_stack())) 91 92 93def init(): 94 r"""Initialize PyTorch's XPU state. 95 This is a Python API about lazy initialization that avoids initializing 96 XPU until the first time it is accessed. Does nothing if the XPU state is 97 already initialized. 98 """ 99 _lazy_init() 100 101 102def _lazy_init(): 103 global _initialized, _queued_calls 104 if is_initialized() or hasattr(_tls, "is_initializing"): 105 return 106 with _initialization_lock: 107 # This test was was protected via GIL. Double-check whether XPU has 108 # already been initialized. 109 if is_initialized(): 110 return 111 # Stop promptly upon encountering a bad fork error. 112 if _is_in_bad_fork(): 113 raise RuntimeError( 114 "Cannot re-initialize XPU in forked subprocess. To use XPU with " 115 "multiprocessing, you must use the 'spawn' start method" 116 ) 117 if not _is_compiled(): 118 raise AssertionError("Torch not compiled with XPU enabled") 119 # This function inits XPU backend and detects bad fork processing. 120 torch._C._xpu_init() 121 # Some of the queued calls may reentrantly call _lazy_init(); We need to 122 # just return without initializing in that case. 123 _tls.is_initializing = True 124 125 for calls in _lazy_seed_tracker.get_calls(): 126 if calls: 127 _queued_calls.append(calls) 128 129 try: 130 for queued_call, orig_traceback in _queued_calls: 131 try: 132 queued_call() 133 except Exception as e: 134 msg = ( 135 f"XPU call failed lazily at initialization with error: {str(e)}\n\n" 136 f"XPU call was originally invoked at:\n\n{''.join(orig_traceback)}" 137 ) 138 raise Exception(msg) from e # noqa: TRY002 139 finally: 140 delattr(_tls, "is_initializing") 141 _initialized = True 142 143 144class _DeviceGuard: 145 def __init__(self, index: int): 146 self.idx = index 147 self.prev_idx = -1 148 149 def __enter__(self): 150 self.prev_idx = torch.xpu._exchange_device(self.idx) 151 152 def __exit__(self, type: Any, value: Any, traceback: Any): 153 self.idx = torch.xpu._maybe_exchange_device(self.prev_idx) 154 return False 155 156 157class device: 158 r"""Context-manager that changes the selected device. 159 160 Args: 161 device (torch.device or int or str): device index to select. It's a no-op if 162 this argument is a negative integer or ``None``. 163 """ 164 165 def __init__(self, device: Any): 166 self.idx = _get_device_index(device, optional=True) 167 self.prev_idx = -1 168 169 def __enter__(self): 170 self.prev_idx = torch.xpu._exchange_device(self.idx) 171 172 def __exit__(self, type: Any, value: Any, traceback: Any): 173 self.idx = torch.xpu._maybe_exchange_device(self.prev_idx) 174 return False 175 176 177class device_of(device): 178 r"""Context-manager that changes the current device to that of given object. 179 180 You can use both tensors and storages as arguments. If a given object is 181 not allocated on a XPU, this is a no-op. 182 183 Args: 184 obj (Tensor or Storage): object allocated on the selected device. 185 """ 186 187 def __init__(self, obj): 188 idx = obj.get_device() if obj.is_xpu else -1 189 super().__init__(idx) 190 191 192def set_device(device: _device_t) -> None: 193 r"""Set the current device. 194 195 Args: 196 device (torch.device or int or str): selected device. This function is a 197 no-op if this argument is negative. 198 """ 199 _lazy_init() 200 device = _get_device_index(device) 201 if device >= 0: 202 torch._C._xpu_setDevice(device) 203 204 205def get_device_name(device: Optional[_device_t] = None) -> str: 206 r"""Get the name of a device. 207 208 Args: 209 device (torch.device or int or str, optional): device for which to 210 return the name. This function is a no-op if this argument is a 211 negative integer. It uses the current device, given by :func:`~torch.xpu.current_device`, 212 if :attr:`device` is ``None`` (default). 213 214 Returns: 215 str: the name of the device 216 """ 217 return get_device_properties(device).name 218 219 220@lru_cache(None) 221def get_device_capability(device: Optional[_device_t] = None) -> Dict[str, Any]: 222 r"""Get the xpu capability of a device. 223 224 Args: 225 device (torch.device or int or str, optional): device for which to 226 return the device capability. This function is a no-op if this 227 argument is a negative integer. It uses the current device, given by 228 :func:`~torch.xpu.current_device`, if :attr:`device` is ``None`` 229 (default). 230 231 Returns: 232 Dict[str, Any]: the xpu capability dictionary of the device 233 """ 234 props = get_device_properties(device) 235 return { 236 prop: getattr(props, prop) for prop in dir(props) if not prop.startswith("__") 237 } 238 239 240def get_device_properties(device: Optional[_device_t] = None) -> _XpuDeviceProperties: 241 r"""Get the properties of a device. 242 243 Args: 244 device (torch.device or int or str): device for which to return the 245 properties of the device. 246 247 Returns: 248 _XpuDeviceProperties: the properties of the device 249 """ 250 _lazy_init() 251 device = _get_device_index(device, optional=True) 252 if device < 0 or device >= device_count(): 253 raise AssertionError("Invalid device index") 254 return _get_device_properties(device) # type: ignore[name-defined] # noqa: F821 255 256 257def current_device() -> int: 258 r"""Return the index of a currently selected device.""" 259 _lazy_init() 260 return torch._C._xpu_getDevice() 261 262 263def _get_device(device: Union[int, str, torch.device]) -> torch.device: 264 r"""Return the torch.device type object from the passed in device. 265 266 Args: 267 device (torch.device or int or str): selected device. 268 """ 269 if isinstance(device, str): 270 device = torch.device(device) 271 elif isinstance(device, int): 272 device = torch.device("xpu", device) 273 return device 274 275 276class StreamContext: 277 r"""Context-manager that selects a given stream. 278 279 All XPU kernels queued within its context will be enqueued on a selected 280 stream. 281 282 Args: 283 Stream (Stream): selected stream. This manager is a no-op if it's 284 ``None``. 285 .. note:: Streams are per-device. 286 """ 287 cur_stream: Optional["torch.xpu.Stream"] 288 289 def __init__(self, stream: Optional["torch.xpu.Stream"]): 290 self.stream = stream 291 self.idx = _get_device_index(None, True) 292 if self.idx is None: 293 self.idx = -1 294 295 def __enter__(self): 296 cur_stream = self.stream 297 if cur_stream is None or self.idx == -1: 298 return 299 self.src_prev_stream = torch.xpu.current_stream(None) 300 301 # If the stream is not on the current device, then set the current stream on the device 302 if self.src_prev_stream.device != cur_stream.device: 303 with device(cur_stream.device): 304 self.dst_prev_stream = torch.xpu.current_stream(cur_stream.device) 305 torch.xpu.set_stream(cur_stream) 306 307 def __exit__(self, type: Any, value: Any, traceback: Any): 308 cur_stream = self.stream 309 if cur_stream is None or self.idx == -1: 310 return 311 312 # Reset the stream on the original device and destination device 313 if self.src_prev_stream.device != cur_stream.device: 314 torch.xpu.set_stream(self.dst_prev_stream) 315 torch.xpu.set_stream(self.src_prev_stream) 316 317 318def stream(stream: Optional["torch.xpu.Stream"]) -> StreamContext: 319 r"""Wrap around the Context-manager StreamContext that selects a given stream. 320 321 Arguments: 322 stream (Stream): selected stream. This manager is a no-op if it's ``None``. 323 """ 324 return StreamContext(stream) 325 326 327def _set_stream_by_id(stream_id, device_index, device_type): 328 r"""set stream specified by the stream id, device index and device type 329 330 Args: stream_id (int): not visible to the user, used to assigned to the specific stream. 331 device_index (int): selected device index. 332 device_type (int): selected device type. 333 """ 334 torch._C._xpu_setStream( 335 stream_id=stream_id, 336 device_index=device_index, 337 device_type=device_type, 338 ) 339 340 341def set_stream(stream: Stream): 342 r"""Set the current stream.This is a wrapper API to set the stream. 343 Usage of this function is discouraged in favor of the ``stream`` 344 context manager. 345 346 Args: 347 stream (Stream): selected stream. This function is a no-op 348 if this argument is ``None``. 349 """ 350 if stream is None: 351 return 352 _lazy_init() 353 _set_stream_by_id( 354 stream_id=stream.stream_id, 355 device_index=stream.device_index, 356 device_type=stream.device_type, 357 ) 358 359 360def current_stream(device: Optional[_device_t] = None) -> Stream: 361 r"""Return the currently selected :class:`Stream` for a given device. 362 363 Args: 364 device (torch.device or int, optional): selected device. Returns 365 the currently selected :class:`Stream` for the current device, given 366 by :func:`~torch.xpu.current_device`, if :attr:`device` is ``None`` 367 (default). 368 """ 369 _lazy_init() 370 streamdata = torch._C._xpu_getCurrentStream( 371 _get_device_index(device, optional=True) 372 ) 373 return Stream( 374 stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2] 375 ) 376 377 378def synchronize(device: _device_t = None) -> None: 379 r"""Wait for all kernels in all streams on a XPU device to complete. 380 381 Args: 382 device (torch.device or int, optional): device for which to synchronize. 383 It uses the current device, given by :func:`~torch.xpu.current_device`, 384 if :attr:`device` is ``None`` (default). 385 """ 386 _lazy_init() 387 device = _get_device_index(device, optional=True) 388 return torch._C._xpu_synchronize(device) 389 390 391def _get_generator(device: torch.device) -> torch._C.Generator: 392 r"""Return the XPU Generator object for the given device. 393 394 Args: 395 device (torch.device): selected device. 396 """ 397 idx = device.index 398 if idx is None: 399 idx = current_device() 400 return torch.xpu.default_generators[idx] 401 402 403def _set_rng_state_offset( 404 offset: int, device: Union[int, str, torch.device] = "xpu" 405) -> None: 406 r"""Set the random number generator state offset of the specified GPU. 407 408 Args: 409 offset (int): The desired offset 410 device (torch.device or int, optional): The device to set the RNG state. 411 Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device). 412 """ 413 final_device = _get_device(device) 414 415 def cb(): 416 default_generator = _get_generator(final_device) 417 default_generator.set_offset(offset) 418 419 _lazy_call(cb) 420 421 422def _get_rng_state_offset(device: Union[int, str, torch.device] = "xpu") -> int: 423 r"""Return the random number generator state offset of the specified GPU. 424 425 Args: 426 device (torch.device or int, optional): The device to return the RNG state offset of. 427 Default: ``'xpu'`` (i.e., ``torch.device('xpu')``, the current XPU device). 428 429 .. warning:: 430 This function eagerly initializes XPU. 431 """ 432 _lazy_init() 433 final_device = _get_device(device) 434 default_generator = _get_generator(final_device) 435 return default_generator.get_offset() 436 437 438# import here to avoid circular import 439from .memory import ( 440 empty_cache, 441 max_memory_allocated, 442 max_memory_reserved, 443 memory_allocated, 444 memory_reserved, 445 memory_stats, 446 memory_stats_as_nested_dict, 447 reset_accumulated_memory_stats, 448 reset_peak_memory_stats, 449) 450from .random import ( 451 get_rng_state, 452 get_rng_state_all, 453 initial_seed, 454 manual_seed, 455 manual_seed_all, 456 seed, 457 seed_all, 458 set_rng_state, 459 set_rng_state_all, 460) 461 462 463__all__ = [ 464 "Event", 465 "Stream", 466 "StreamContext", 467 "current_device", 468 "current_stream", 469 "default_generators", 470 "device", 471 "device_of", 472 "device_count", 473 "empty_cache", 474 "get_device_capability", 475 "get_device_name", 476 "get_device_properties", 477 "get_rng_state", 478 "get_rng_state_all", 479 "get_stream", 480 "init", 481 "initial_seed", 482 "is_available", 483 "is_bf16_supported", 484 "is_initialized", 485 "manual_seed", 486 "manual_seed_all", 487 "max_memory_allocated", 488 "max_memory_reserved", 489 "memory_allocated", 490 "memory_reserved", 491 "memory_stats", 492 "memory_stats_as_nested_dict", 493 "reset_accumulated_memory_stats", 494 "reset_peak_memory_stats", 495 "seed", 496 "seed_all", 497 "set_device", 498 "set_rng_state", 499 "set_rng_state_all", 500 "set_stream", 501 "stream", 502 "streams", 503 "synchronize", 504] 505