# mypy: allow-untyped-defs
import ctypes

import torch
from torch._streambase import _EventBase, _StreamBase

from .._utils import _dummy_type


if not hasattr(torch._C, "_XpuStreamBase"):
    # Define dummy base classes
    torch._C.__dict__["_XpuStreamBase"] = _dummy_type("_XpuStreamBase")
    torch._C.__dict__["_XpuEventBase"] = _dummy_type("_XpuEventBase")


class Stream(torch._C._XpuStreamBase, _StreamBase):
    r"""Wrapper around an XPU stream.

    An XPU stream is a linear sequence of execution that belongs to a specific
    device, independent from other streams.

    Args:
        device (torch.device or int, optional): a device on which to allocate
            the stream. If :attr:`device` is ``None`` (default) or a negative
            integer, this will use the current device.
        priority (int, optional): priority of the stream, which should be 0 or
            negative, where negative numbers indicate higher priority. By default,
            streams have priority 0.
    """

    def __new__(cls, device=None, priority=0, **kwargs):
        # Switching the device is expensive, so we avoid it unless necessary.
        if device is None or ("stream_id" in kwargs and "device_index" in kwargs):
            return super().__new__(cls, priority=priority, **kwargs)
        else:
            with torch.xpu.device(device):
                return super().__new__(cls, priority=priority, **kwargs)

    def wait_event(self, event) -> None:
        r"""Make all future work submitted to the stream wait for an event.

        Args:
            event (torch.xpu.Event): an event to wait for.
        """
        event.wait(self)

    def wait_stream(self, stream) -> None:
        r"""Synchronize with another stream.

        All future work submitted to this stream will wait until all kernels
        submitted to the given stream at the time of this call have completed.

        Args:
            stream (Stream): a stream to synchronize with.
        """
        self.wait_event(stream.record_event())

    def record_event(self, event=None):
        r"""Record an event.

        Args:
            event (torch.xpu.Event, optional): event to record. If not given, a new one
                will be allocated.

        Returns:
            Recorded event.
        """
        if event is None:
            event = Event()
        event.record(self)
        return event

    def query(self) -> bool:
        r"""Check if all the work submitted has been completed.

        Returns:
            A boolean indicating if all kernels in this stream are completed.
        """
        return super().query()

    def synchronize(self) -> None:
        r"""Wait for all the kernels in this stream to complete."""
        super().synchronize()

    @property
    def _as_parameter_(self):
        # Expose the underlying SYCL queue as a ctypes pointer so the stream
        # can be passed directly to C functions via ctypes.
        return ctypes.c_void_p(self.sycl_queue)

    def __eq__(self, o):
        if isinstance(o, Stream):
            return super().__eq__(o)
        return False

    def __hash__(self):
        return hash((self.sycl_queue, self.device))

    def __repr__(self):
        return f"torch.xpu.Stream(device={self.device} sycl_queue={self.sycl_queue:#x})"


class Event(torch._C._XpuEventBase, _EventBase):
    r"""Wrapper around an XPU event.

    XPU events are synchronization markers that can be used to monitor the
    device's progress and to synchronize XPU streams.

    The underlying XPU event is lazily initialized when the event is first
    recorded. After creation, only streams on the same device may record the
    event. However, streams on any device can wait on the event.

    Args:
        enable_timing (bool, optional): indicates if the event should measure time
            (default: ``False``)
    """

    def __new__(cls, enable_timing=False):
        return super().__new__(cls, enable_timing=enable_timing)

    def record(self, stream=None) -> None:
        r"""Record the event in a given stream.

        Uses ``torch.xpu.current_stream()`` if no stream is specified. The
        stream's device must match the event's device.
        """
        if stream is None:
            stream = torch.xpu.current_stream()
        super().record(stream)

    def wait(self, stream=None) -> None:
        r"""Make all future work submitted to the given stream wait for this event.

        Uses ``torch.xpu.current_stream()`` if no stream is specified.
        """
        if stream is None:
            stream = torch.xpu.current_stream()
        super().wait(stream)

    def query(self) -> bool:
        r"""Check if all work currently captured by the event has completed.

        Returns:
            A boolean indicating if all work currently captured by the event has
            completed.
        """
        return super().query()

    def elapsed_time(self, end_event):
        r"""Return the time elapsed.

        Time reported in milliseconds after this event was recorded and
        before ``end_event`` was recorded.
        """
        return super().elapsed_time(end_event)

    def synchronize(self) -> None:
        r"""Wait for the event to complete.

        Waits until the completion of all work currently captured in this event.
        This prevents the CPU thread from proceeding until the event completes.
        """
        super().synchronize()

    @property
    def _as_parameter_(self):
        # Expose the underlying SYCL event as a ctypes pointer so the event
        # can be passed directly to C functions via ctypes.
        return ctypes.c_void_p(self.sycl_event)

    def __repr__(self):
        if self.sycl_event:
            return f"torch.xpu.Event(sycl_event={self.sycl_event:#x})"
        else:
            return "torch.xpu.Event(uninitialized)"
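

# A minimal usage sketch (illustrative only, not part of the module's public
# surface): it combines the Stream and Event wrappers defined above to run and
# time a matmul on a side stream. It assumes an XPU device is available and that
# the backend supports event timing; the __main__ guard keeps imports unaffected.
if __name__ == "__main__":
    if torch.xpu.is_available():
        side_stream = Stream()
        start = Event(enable_timing=True)
        end = Event(enable_timing=True)
        a = torch.randn(1024, 1024, device="xpu")
        with torch.xpu.stream(side_stream):
            start.record()  # records on side_stream, the current stream here
            b = a @ a
            end.record()
        # Make the default stream wait for the side stream before consuming `b`.
        torch.xpu.current_stream().wait_stream(side_stream)
        end.synchronize()
        print(f"matmul: {start.elapsed_time(end):.3f} ms, |b| = {b.norm().item():.3f}")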