xref: /aosp_15_r20/external/pytorch/torch/xpu/streams.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerimport ctypes
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerimport torch
5*da0073e9SAndroid Build Coastguard Workerfrom torch._streambase import _EventBase, _StreamBase
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerfrom .._utils import _dummy_type
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Worker
10*da0073e9SAndroid Build Coastguard Workerif not hasattr(torch._C, "_XpuStreamBase"):
11*da0073e9SAndroid Build Coastguard Worker    # Define dummy base classes
12*da0073e9SAndroid Build Coastguard Worker    torch._C.__dict__["_XpuStreamBase"] = _dummy_type("_XpuStreamBase")
13*da0073e9SAndroid Build Coastguard Worker    torch._C.__dict__["_XpuEventBase"] = _dummy_type("_XpuEventBase")
14*da0073e9SAndroid Build Coastguard Worker
15*da0073e9SAndroid Build Coastguard Worker
class Stream(torch._C._XpuStreamBase, _StreamBase):
    r"""Python wrapper around an XPU stream.

    An XPU stream is a linear sequence of execution that belongs to a single
    device and runs independently of other streams.

    Args:
        device(torch.device or int, optional): the device on which to allocate
            the stream. When :attr:`device` is ``None`` (the default) or a
            negative integer, the current device is used.
        priority(int, optional): priority of the stream; it should be 0 or
            negative, with negative numbers meaning higher priority. Streams
            have priority 0 unless another value is given.
    """

    def __new__(cls, device=None, priority=0, **kwargs):
        # Switching the active device is expensive. Only do it when a device
        # was requested AND the caller has not already identified the stream
        # by passing both ``stream_id`` and ``device_index``.
        needs_device_guard = device is not None and not (
            "stream_id" in kwargs and "device_index" in kwargs
        )
        if not needs_device_guard:
            return super().__new__(cls, priority=priority, **kwargs)
        with torch.xpu.device(device):
            return super().__new__(cls, priority=priority, **kwargs)

    def wait_event(self, event) -> None:
        r"""Make all work submitted to this stream from now on wait for ``event``.

        Args:
            event (torch.xpu.Event): the event to wait for.
        """
        # The event carries out the wait, with this stream as the waiter.
        event.wait(self)

    def wait_stream(self, stream) -> None:
        r"""Synchronize this stream with another one.

        Work submitted to this stream after this call waits until every
        kernel that had been submitted to ``stream`` at the time of the call
        has completed.

        Args:
            stream (Stream): the stream to synchronize with.
        """
        # Mark ``stream``'s current position with a freshly recorded event,
        # then have this stream wait on that marker.
        marker = stream.record_event()
        self.wait_event(marker)

    def record_event(self, event=None):
        r"""Record an event on this stream.

        Args:
            event (torch.xpu.Event, optional): the event to record. A new
                event is allocated when none is given.

        Returns:
            The event that was recorded.
        """
        target = Event() if event is None else event
        target.record(self)
        return target

    def query(self) -> bool:
        r"""Report whether all submitted work has finished.

        Returns:
            A boolean that is true when every kernel in this stream has
            completed.
        """
        return super().query()

    def synchronize(self) -> None:
        r"""Block until every kernel in this stream has completed."""
        super().synchronize()

    @property
    def _as_parameter_(self):
        # ctypes hook: a Stream passed to a foreign function is converted to
        # a void pointer built from its ``sycl_queue`` handle.
        queue_handle = self.sycl_queue
        return ctypes.c_void_p(queue_handle)

    def __eq__(self, o):
        # Non-Streams compare unequal (plain False, not NotImplemented);
        # Streams defer to the base-class comparison.
        return isinstance(o, Stream) and super().__eq__(o)

    def __hash__(self):
        # Defined alongside __eq__ so instances stay hashable; keyed on the
        # SYCL queue handle together with the device.
        key = (self.sycl_queue, self.device)
        return hash(key)

    def __repr__(self):
        return f"torch.xpu.Stream(device={self.device} sycl_queue={self.sycl_queue:#x})"
99*da0073e9SAndroid Build Coastguard Worker
100*da0073e9SAndroid Build Coastguard Worker
class Event(torch._C._XpuEventBase, _EventBase):
    r"""Python wrapper around an XPU event.

    XPU events are synchronization markers used to track the device's
    progress and to synchronize XPU streams.

    The underlying XPU event is created lazily, the first time the event is
    recorded. From then on only streams on that same device may record it,
    while streams on any device may wait on it.

    Args:
        enable_timing (bool, optional): whether the event should measure time
            (default: ``False``)
    """

    def __new__(cls, enable_timing=False):
        return super().__new__(cls, enable_timing=enable_timing)

    def record(self, stream=None) -> None:
        r"""Record this event on a stream.

        Falls back to ``torch.xpu.current_stream()`` when no stream is given.
        The stream must be on the same device as the event.
        """
        target = torch.xpu.current_stream() if stream is None else stream
        super().record(target)

    def wait(self, stream=None) -> None:
        r"""Make all work submitted to ``stream`` from now on wait for this event.

        Falls back to ``torch.xpu.current_stream()`` when no stream is given.
        """
        target = torch.xpu.current_stream() if stream is None else stream
        super().wait(target)

    def query(self) -> bool:
        r"""Report whether all work currently captured by this event has finished.

        Returns:
            A boolean that is true when all work currently captured by the
            event has completed.
        """
        return super().query()

    def elapsed_time(self, end_event):
        r"""Return the elapsed time between two recorded events.

        The result is in milliseconds, measured from when this event was
        recorded until ``end_event`` was recorded.
        """
        return super().elapsed_time(end_event)

    def synchronize(self) -> None:
        r"""Block the calling CPU thread until this event completes.

        Returns only once all work currently captured in this event has
        finished.
        """
        super().synchronize()

    @property
    def _as_parameter_(self):
        # ctypes hook: an Event passed to a foreign function is converted to
        # a void pointer built from its ``sycl_event`` handle.
        event_handle = self.sycl_event
        return ctypes.c_void_p(event_handle)

    def __repr__(self):
        # A falsy handle means the lazily-created native event does not
        # exist yet.
        if not self.sycl_event:
            return "torch.xpu.Event(uninitialized)"
        return f"torch.xpu.Event(sycl_event={self.sycl_event:#x})"
171*da0073e9SAndroid Build Coastguard Worker            return "torch.xpu.Event(uninitialized)"
172