1# mypy: allow-untyped-defs 2import inspect 3from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union 4 5import torch 6from torch._streambase import _EventBase, _StreamBase 7 8 9get_cuda_stream: Optional[Callable[[int], int]] 10if torch.cuda._is_compiled(): 11 from torch._C import _cuda_getCurrentRawStream as get_cuda_stream 12else: 13 get_cuda_stream = None 14 15_device_t = Union[torch.device, str, int, None] 16 17# Recording the device properties in the main process but used in worker process. 18caching_worker_device_properties: Dict[str, Any] = {} 19caching_worker_current_devices: Dict[str, int] = {} 20 21 22class DeviceInterfaceMeta(type): 23 def __new__(metacls, *args, **kwargs): 24 class_member = args[2] 25 if "Event" in class_member: 26 assert inspect.isclass(class_member["Event"]) and issubclass( 27 class_member["Event"], _EventBase 28 ), "DeviceInterface member Event should be inherit from _EventBase" 29 if "Stream" in class_member: 30 assert inspect.isclass(class_member["Stream"]) and issubclass( 31 class_member["Stream"], _StreamBase 32 ), "DeviceInterface member Stream should be inherit from _StreamBase" 33 return super().__new__(metacls, *args, **kwargs) 34 35 36class DeviceInterface(metaclass=DeviceInterfaceMeta): 37 """ 38 This is a simple device runtime interface for Inductor. It enables custom 39 backends to be integrated with Inductor in a device-agnostic semantic. 40 """ 41 42 class device: 43 def __new__(cls, device: _device_t): 44 raise NotImplementedError 45 46 class Worker: 47 """ 48 Worker API to query device properties that will work in multi processing 49 workers that cannot use the GPU APIs (due to processing fork() and 50 initialization time issues). Properties are recorded in the main process 51 before we fork the workers. 52 """ 53 54 @staticmethod 55 def set_device(device: int): 56 raise NotImplementedError 57 58 @staticmethod 59 def current_device() -> int: 60 raise NotImplementedError 61 62 @staticmethod 63 def get_device_properties(device: _device_t = None): 64 raise NotImplementedError 65 66 @staticmethod 67 def current_device(): 68 raise NotImplementedError 69 70 @staticmethod 71 def set_device(device: _device_t): 72 raise NotImplementedError 73 74 @staticmethod 75 def maybe_exchange_device(device: int) -> int: 76 raise NotImplementedError 77 78 @staticmethod 79 def exchange_device(device: int) -> int: 80 raise NotImplementedError 81 82 @staticmethod 83 def device_count(): 84 raise NotImplementedError 85 86 @staticmethod 87 def is_available() -> bool: 88 raise NotImplementedError 89 90 @staticmethod 91 def stream(stream: torch.Stream): 92 raise NotImplementedError 93 94 @staticmethod 95 def current_stream(): 96 raise NotImplementedError 97 98 @staticmethod 99 def set_stream(stream: torch.Stream): 100 raise NotImplementedError 101 102 @staticmethod 103 def _set_stream_by_id(stream_id: int, device_index: int, device_type: int): 104 raise NotImplementedError 105 106 @staticmethod 107 def get_raw_stream(device_idx: int) -> int: 108 raise NotImplementedError 109 110 @staticmethod 111 def synchronize(device: _device_t = None): 112 raise NotImplementedError 113 114 @staticmethod 115 def get_device_properties(device: _device_t = None): 116 raise NotImplementedError 117 118 @staticmethod 119 def get_compute_capability(device: _device_t = None): 120 raise NotImplementedError 121 122 @staticmethod 123 def is_bf16_supported(including_emulation: bool = False): 124 raise NotImplementedError 125 126 127class DeviceGuard: 128 """ 129 This class provides a context manager for device switching. This is a stripped 130 down version of torch.{device_name}.device. 131 132 The context manager changes the current device to the given device index 133 on entering the context and restores the original device on exiting. 134 The device is switched using the provided device interface. 135 """ 136 137 def __init__( 138 self, device_interface: Type[DeviceInterface], index: Optional[int] 139 ) -> None: 140 self.device_interface = device_interface 141 self.idx = index 142 self.prev_idx = -1 143 144 def __enter__(self): 145 if self.idx is not None: 146 self.prev_idx = self.device_interface.exchange_device(self.idx) 147 148 def __exit__(self, type: Any, value: Any, traceback: Any): 149 if self.idx is not None: 150 self.idx = self.device_interface.maybe_exchange_device(self.prev_idx) 151 return False 152 153 154class CudaInterface(DeviceInterface): 155 device = torch.cuda.device 156 157 # register Event and Stream class into the backend interface 158 # make sure Event and Stream are implemented and inherited from the _EventBase and _StreamBase 159 Event = torch.cuda.Event 160 Stream = torch.cuda.Stream 161 162 class Worker: 163 @staticmethod 164 def set_device(device: int): 165 caching_worker_current_devices["cuda"] = device 166 167 @staticmethod 168 def current_device() -> int: 169 if "cuda" in caching_worker_current_devices: 170 return caching_worker_current_devices["cuda"] 171 return torch.cuda.current_device() 172 173 @staticmethod 174 def get_device_properties(device: _device_t = None): 175 if device is not None: 176 if isinstance(device, str): 177 device = torch.device(device) 178 assert device.type == "cuda" 179 if isinstance(device, torch.device): 180 device = device.index 181 if device is None: 182 device = CudaInterface.Worker.current_device() 183 184 if "cuda" not in caching_worker_device_properties: 185 device_prop = [ 186 torch.cuda.get_device_properties(i) 187 for i in range(torch.cuda.device_count()) 188 ] 189 caching_worker_device_properties["cuda"] = device_prop 190 191 return caching_worker_device_properties["cuda"][device] 192 193 current_device = staticmethod(torch.cuda.current_device) 194 set_device = staticmethod(torch.cuda.set_device) 195 device_count = staticmethod(torch.cuda.device_count) 196 stream = staticmethod(torch.cuda.stream) # type: ignore[assignment] 197 current_stream = staticmethod(torch.cuda.current_stream) 198 set_stream = staticmethod(torch.cuda.set_stream) # type: ignore[assignment] 199 _set_stream_by_id = staticmethod(torch.cuda._set_stream_by_id) # type: ignore[assignment] 200 synchronize = staticmethod(torch.cuda.synchronize) 201 get_device_properties = staticmethod(torch.cuda.get_device_properties) # type: ignore[assignment] 202 get_raw_stream = staticmethod(get_cuda_stream) # type: ignore[assignment, arg-type] 203 exchange_device = staticmethod(torch.cuda._exchange_device) # type: ignore[arg-type] 204 maybe_exchange_device = staticmethod(torch.cuda._maybe_exchange_device) # type: ignore[arg-type] 205 is_bf16_supported = staticmethod(torch.cuda.is_bf16_supported) # type: ignore[arg-type] 206 207 # Can be mock patched by @patch decorator. 208 @staticmethod 209 def is_available() -> bool: 210 return torch.cuda.is_available() 211 212 @staticmethod 213 def get_compute_capability(device: _device_t = None): 214 if torch.version.hip is None: 215 major, min = torch.cuda.get_device_capability(device) 216 return major * 10 + min 217 else: 218 return torch.cuda.get_device_properties(device).gcnArchName.split(":", 1)[0] 219 220 221get_xpu_stream: Optional[Callable[[int], int]] 222if torch.xpu._is_compiled(): 223 from torch._C import _xpu_getCurrentRawStream as get_xpu_stream 224else: 225 get_xpu_stream = None 226 227 228class XpuInterface(DeviceInterface): 229 device = torch.xpu.device 230 Event = torch.xpu.Event 231 Stream = torch.xpu.Stream 232 233 class Worker: 234 @staticmethod 235 def set_device(device: int): 236 caching_worker_current_devices["xpu"] = device 237 238 @staticmethod 239 def current_device() -> int: 240 if "xpu" in caching_worker_current_devices: 241 return caching_worker_current_devices["xpu"] 242 return torch.xpu.current_device() 243 244 @staticmethod 245 def get_device_properties(device: _device_t = None): 246 if device is not None: 247 if isinstance(device, str): 248 device = torch.device(device) 249 assert device.type == "xpu" 250 if isinstance(device, torch.device): 251 device = device.index 252 if device is None: 253 device = XpuInterface.Worker.current_device() 254 255 if "xpu" not in caching_worker_device_properties: 256 device_prop = [ 257 torch.xpu.get_device_properties(i) 258 for i in range(torch.xpu.device_count()) 259 ] 260 caching_worker_device_properties["xpu"] = device_prop 261 262 return caching_worker_device_properties["xpu"][device] 263 264 current_device = staticmethod(torch.xpu.current_device) 265 set_device = staticmethod(torch.xpu.set_device) 266 device_count = staticmethod(torch.xpu.device_count) 267 stream = staticmethod(torch.xpu.stream) # type: ignore[assignment] 268 current_stream = staticmethod(torch.xpu.current_stream) 269 set_stream = staticmethod(torch.xpu.set_stream) # type: ignore[assignment] 270 _set_stream_by_id = staticmethod(torch.xpu._set_stream_by_id) # type: ignore[assignment] 271 synchronize = staticmethod(torch.xpu.synchronize) 272 get_device_properties = staticmethod(torch.xpu.get_device_properties) # type: ignore[assignment] 273 get_raw_stream = staticmethod(get_xpu_stream) # type: ignore[assignment, arg-type] 274 exchange_device = staticmethod(torch.xpu._exchange_device) # type: ignore[arg-type] 275 maybe_exchange_device = staticmethod(torch.xpu._maybe_exchange_device) # type: ignore[arg-type] 276 277 # Can be mock patched by @patch decorator. 278 @staticmethod 279 def is_available() -> bool: 280 return torch.xpu.is_available() 281 282 @staticmethod 283 def get_compute_capability(device: _device_t = None): 284 cc = torch.xpu.get_device_capability(device) 285 return cc 286 287 @staticmethod 288 def is_bf16_supported(including_emulation: bool = False) -> bool: 289 return torch.xpu.is_bf16_supported() 290 291 292device_interfaces: Dict[str, Type[DeviceInterface]] = {} 293_device_initialized = False 294 295 296def register_interface_for_device( 297 device: Union[str, torch.device], device_interface: Type[DeviceInterface] 298): 299 if isinstance(device, torch.device): 300 device = str(device) 301 device_interfaces[device] = device_interface 302 303 304def get_interface_for_device(device: Union[str, torch.device]) -> Type[DeviceInterface]: 305 if isinstance(device, torch.device): 306 device = str(device) 307 if not _device_initialized: 308 init_device_reg() 309 if device in device_interfaces: 310 return device_interfaces[device] 311 raise NotImplementedError(f"No interface for device {device}") 312 313 314def get_registered_device_interfaces() -> Iterable[Tuple[str, Type[DeviceInterface]]]: 315 if not _device_initialized: 316 init_device_reg() 317 return device_interfaces.items() 318 319 320def init_device_reg(): 321 global _device_initialized 322 register_interface_for_device("cuda", CudaInterface) 323 for i in range(torch.cuda.device_count()): 324 register_interface_for_device(f"cuda:{i}", CudaInterface) 325 326 register_interface_for_device("xpu", XpuInterface) 327 for i in range(torch.xpu.device_count()): 328 register_interface_for_device(f"xpu:{i}", XpuInterface) 329 330 _device_initialized = True 331