# mypy: allow-untyped-defs
"""Stream and event wrappers for the XPU backend."""
import ctypes

import torch
from torch._streambase import _EventBase, _StreamBase

from .._utils import _dummy_type


if not hasattr(torch._C, "_XpuStreamBase"):
    # Define dummy base classes so the class statements below still work
    # when torch._C does not provide the XPU stream/event bases.
    torch._C.__dict__["_XpuStreamBase"] = _dummy_type("_XpuStreamBase")
    torch._C.__dict__["_XpuEventBase"] = _dummy_type("_XpuEventBase")


class Stream(torch._C._XpuStreamBase, _StreamBase):
    r"""Wrapper around an XPU stream.

    An XPU stream is a linear sequence of execution that belongs to a specific
    device, independent from other streams.

    Args:
        device(torch.device or int, optional): a device on which to allocate
            the stream. If :attr:`device` is ``None`` (default) or a negative
            integer, this will use the current device.
        priority(int, optional): priority of the stream, should be 0 or
            negative, where negative numbers indicate higher priority. By default,
            streams have priority 0.
    """

    def __new__(cls, device=None, priority=0, **kwargs):
        r"""Create the stream, entering ``torch.xpu.device(device)`` only when needed.

        The device context is skipped when :attr:`device` is ``None`` or when
        both ``stream_id`` and ``device_index`` are supplied in ``kwargs``.
        """
        # setting device manager is expensive, so we avoid it unless necessary
        if device is None or ("stream_id" in kwargs and "device_index" in kwargs):
            return super().__new__(cls, priority=priority, **kwargs)
        else:
            with torch.xpu.device(device):
                return super().__new__(cls, priority=priority, **kwargs)

    def wait_event(self, event) -> None:
        r"""Make all future work submitted to the stream wait for an event.

        Args:
            event (torch.xpu.Event): an event to wait for.
        """
        event.wait(self)

    def wait_stream(self, stream) -> None:
        r"""Synchronize with another stream.

        All future work submitted to this stream will wait until all kernels
        submitted to a given stream at the time of call complete.

        Args:
            stream (Stream): a stream to synchronize.
        """
        # Record a fresh event on the other stream, then wait on it here.
        self.wait_event(stream.record_event())

    def record_event(self, event=None):
        r"""Record an event.

        Args:
            event (torch.xpu.Event, optional): event to record. If not given, a new one
                will be allocated.

        Returns:
            Recorded event.
        """
        if event is None:
            event = Event()
        event.record(self)
        return event

    def query(self) -> bool:
        r"""Check if all the work submitted has been completed.

        Returns:
            A boolean indicating if all kernels in this stream are completed.
        """
        return super().query()

    def synchronize(self) -> None:
        r"""Wait for all the kernels in this stream to complete."""
        super().synchronize()

    @property
    def _as_parameter_(self):
        r"""Return ``sycl_queue`` wrapped in a ``ctypes.c_void_p``."""
        return ctypes.c_void_p(self.sycl_queue)

    def __eq__(self, o):
        r"""Compare with another ``Stream`` via the base class; ``False`` for other types."""
        if isinstance(o, Stream):
            return super().__eq__(o)
        return False

    def __hash__(self):
        r"""Hash the ``(sycl_queue, device)`` pair."""
        return hash((self.sycl_queue, self.device))

    def __repr__(self):
        r"""Return a string showing the device and the ``sycl_queue`` value in hex."""
        return f"torch.xpu.Stream(device={self.device} sycl_queue={self.sycl_queue:#x})"


class Event(torch._C._XpuEventBase, _EventBase):
    r"""Wrapper around an XPU event.

    XPU events are synchronization markers that can be used to monitor the
    device's progress, and to synchronize XPU streams.

    The underlying XPU events are lazily initialized when the event is first
    recorded. After creation, only streams on the same device may record the
    event. However, streams on any device can wait on the event.

    Args:
        enable_timing (bool, optional): indicates if the event should measure time
            (default: ``False``)
    """

    def __new__(cls, enable_timing=False):
        r"""Create the event, forwarding ``enable_timing`` to the base class."""
        return super().__new__(cls, enable_timing=enable_timing)

    def record(self, stream=None) -> None:
        r"""Record the event in a given stream.

        Uses ``torch.xpu.current_stream()`` if no stream is specified. The
        stream's device must match the event's device.
        """
        if stream is None:
            stream = torch.xpu.current_stream()
        super().record(stream)

    def wait(self, stream=None) -> None:
        r"""Make all future work submitted to the given stream wait for this event.

        Use ``torch.xpu.current_stream()`` if no stream is specified.
        """
        if stream is None:
            stream = torch.xpu.current_stream()
        super().wait(stream)

    def query(self) -> bool:
        r"""Check if all work currently captured by event has completed.

        Returns:
            A boolean indicating if all work currently captured by event has
            completed.
        """
        return super().query()

    def elapsed_time(self, end_event):
        r"""Return the time elapsed.

        Time reported in milliseconds after the event was recorded and
        before the end_event was recorded.
        """
        return super().elapsed_time(end_event)

    def synchronize(self) -> None:
        r"""Wait for the event to complete.

        Waits until the completion of all work currently captured in this event.
        This prevents the CPU thread from proceeding until the event completes.
        """
        super().synchronize()

    @property
    def _as_parameter_(self):
        r"""Return ``sycl_event`` wrapped in a ``ctypes.c_void_p``."""
        return ctypes.c_void_p(self.sycl_event)

    def __repr__(self):
        r"""Return a string with ``sycl_event`` in hex, or an ``uninitialized`` marker if it is falsy."""
        if self.sycl_event:
            return f"torch.xpu.Event(sycl_event={self.sycl_event:#x})"
        else:
            return "torch.xpu.Event(uninitialized)"