# mypy: allow-untyped-defs
import ctypes

import torch
from torch._streambase import _EventBase, _StreamBase

from .._utils import _dummy_type


if not hasattr(torch._C, "_XpuStreamBase"):
    # Define dummy base classes
    torch._C.__dict__["_XpuStreamBase"] = _dummy_type("_XpuStreamBase")
    torch._C.__dict__["_XpuEventBase"] = _dummy_type("_XpuEventBase")


class Stream(torch._C._XpuStreamBase, _StreamBase):
    r"""Wrapper around an XPU stream.

    An XPU stream is a linear sequence of execution that belongs to a specific
    device, independent from other streams.

    Args:
        device(torch.device or int, optional): a device on which to allocate
            the stream. If :attr:`device` is ``None`` (default) or a negative
            integer, this will use the current device.
        priority(int, optional): priority of the stream, which should be 0 or
            negative, where negative numbers indicate higher priority. By default,
            streams have priority 0.
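
    A minimal usage sketch (assumes an XPU device is available; the tensor
    ``x`` is illustrative)::

        >>> s = torch.xpu.Stream(priority=-1)  # high-priority stream on the current device
        >>> with torch.xpu.stream(s):
        ...     y = x * 2  # the kernel is enqueued on ``s``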
    """

    def __new__(cls, device=None, priority=0, **kwargs):
        # Switching the current device is expensive, so we avoid it unless necessary.
        if device is None or ("stream_id" in kwargs and "device_index" in kwargs):
            return super().__new__(cls, priority=priority, **kwargs)
        else:
            with torch.xpu.device(device):
                return super().__new__(cls, priority=priority, **kwargs)

    def wait_event(self, event) -> None:
        r"""Make all future work submitted to the stream wait for an event.

        Args:
            event (torch.xpu.Event): an event to wait for.
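
        A minimal sketch (assumes an XPU device is available)::

            >>> e = torch.xpu.Event()
            >>> s = torch.xpu.Stream()
            >>> e.record()  # records ``e`` on the current stream
            >>> s.wait_event(e)  # future work on ``s`` waits for ``e``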
        """
        event.wait(self)

    def wait_stream(self, stream) -> None:
        r"""Synchronize with another stream.

        All future work submitted to this stream will wait until all kernels
        submitted to the given stream at the time of this call have completed.

        Args:
            stream (Stream): a stream to synchronize with.
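
        A minimal sketch of cross-stream ordering (assumes an XPU device is
        available; the tensor ``x`` is illustrative)::

            >>> s1, s2 = torch.xpu.Stream(), torch.xpu.Stream()
            >>> with torch.xpu.stream(s1):
            ...     y = x * 2
            >>> s2.wait_stream(s1)  # later work queued on ``s2`` waits for ``s1``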
        """
        self.wait_event(stream.record_event())

    def record_event(self, event=None):
        r"""Record an event.

        Args:
            event (torch.xpu.Event, optional): event to record. If not given, a new one
                will be allocated.

        Returns:
            Recorded event.
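
        A minimal sketch (assumes an XPU device is available)::

            >>> s = torch.xpu.Stream()
            >>> e = s.record_event()  # allocates a new event and records it on ``s``
            >>> e.synchronize()  # block the CPU until the work captured by ``e`` finishes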
        """
        if event is None:
            event = Event()
        event.record(self)
        return event

    def query(self) -> bool:
        r"""Check if all the work submitted has been completed.

        Returns:
            A boolean indicating if all kernels in this stream are completed.
        """
        return super().query()

    def synchronize(self) -> None:
        r"""Wait for all the kernels in this stream to complete."""
        super().synchronize()

    @property
    def _as_parameter_(self):
        # Expose the raw SYCL queue handle so the stream can be passed to ctypes calls.
        return ctypes.c_void_p(self.sycl_queue)

    def __eq__(self, o):
        if isinstance(o, Stream):
            return super().__eq__(o)
        return False

    def __hash__(self):
        return hash((self.sycl_queue, self.device))

    def __repr__(self):
        return f"torch.xpu.Stream(device={self.device} sycl_queue={self.sycl_queue:#x})"


class Event(torch._C._XpuEventBase, _EventBase):
    r"""Wrapper around an XPU event.

    XPU events are synchronization markers that can be used to monitor the
    device's progress and to synchronize XPU streams.

    The underlying XPU events are lazily initialized when the event is first
    recorded. After creation, only streams on the same device may record the
    event. However, streams on any device can wait on the event.

    Args:
        enable_timing (bool, optional): indicates if the event should measure time
            (default: ``False``)
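
    A minimal timing sketch (assumes an XPU device is available; the tensor
    ``x`` is illustrative)::

        >>> start = torch.xpu.Event(enable_timing=True)
        >>> end = torch.xpu.Event(enable_timing=True)
        >>> start.record()
        >>> y = x * 2  # work enqueued on the current stream
        >>> end.record()
        >>> end.synchronize()
        >>> ms = start.elapsed_time(end)  # elapsed time in milliseconds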
    """

    def __new__(cls, enable_timing=False):
        return super().__new__(cls, enable_timing=enable_timing)

    def record(self, stream=None) -> None:
        r"""Record the event in a given stream.

        Uses ``torch.xpu.current_stream()`` if no stream is specified. The
        stream's device must match the event's device.
        """
        if stream is None:
            stream = torch.xpu.current_stream()
        super().record(stream)

    def wait(self, stream=None) -> None:
        r"""Make all future work submitted to the given stream wait for this event.

        Uses ``torch.xpu.current_stream()`` if no stream is specified.
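
        A minimal sketch (assumes an XPU device is available)::

            >>> e = torch.xpu.Event()
            >>> s1, s2 = torch.xpu.Stream(), torch.xpu.Stream()
            >>> with torch.xpu.stream(s1):
            ...     e.record()  # records ``e`` on ``s1``
            >>> e.wait(s2)  # future work on ``s2`` waits for the work captured by ``e``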
        """
        if stream is None:
            stream = torch.xpu.current_stream()
        super().wait(stream)

    def query(self) -> bool:
        r"""Check if all work currently captured by this event has completed.

        Returns:
            A boolean indicating if all work currently captured by this event
            has completed.
        """
        return super().query()

    def elapsed_time(self, end_event):
        r"""Return the time elapsed.

        Time reported in milliseconds after the event was recorded and
        before the end_event was recorded.
        """
        return super().elapsed_time(end_event)

    def synchronize(self) -> None:
        r"""Wait for the event to complete.

        Waits until the completion of all work currently captured in this event.
        This prevents the CPU thread from proceeding until the event completes.
        """
        super().synchronize()

    @property
    def _as_parameter_(self):
        # Expose the raw SYCL event handle so the event can be passed to ctypes calls.
        return ctypes.c_void_p(self.sycl_event)

    def __repr__(self):
        if self.sycl_event:
            return f"torch.xpu.Event(sycl_event={self.sycl_event:#x})"
        else:
            return "torch.xpu.Event(uninitialized)"