xref: /aosp_15_r20/external/pytorch/c10/xpu/XPUStream.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/Stream.h>
4 #include <c10/core/impl/GPUTrace.h>
5 #include <c10/xpu/XPUFunctions.h>
6 
7 namespace c10::xpu {
8 
/*
 * Note [Stream Management]
 *
 * An XPUStream is an abstraction of an actual SYCL queue in which SYCL kernels
 * can execute. Currently, there are several pools per device to manage SYCL
 * queues, and a device's pools are lazily created.
 *
 * There are two pools per device. The first pool contains "normal priority"
 * queues. The second pool contains "high priority" queues. There are 32 queues
 * per pool per device, and when a queue is requested one of these queues is
 * returned round-robin. That is, the first queue requested is at index 0, the
 * second at index 1... up to index 31, then index 0 again.
 *
 * This means that if 33 queues are requested, the first and last queues
 * requested are actually the same queue (under the covers) and kernels enqueued
 * on them cannot run concurrently.
 *
 * It is safe to enqueue a kernel on the same queue from two different
 * threads, as described by the SYCL specification.
 */
29 
/// Number of stream priority levels PyTorch supports for XPU (normal and high).
/// `inline constexpr` gives a single program-wide definition instead of a
/// separate internal copy in every translation unit that includes this header.
inline constexpr int max_compile_time_stream_priorities = 2;
31 
32 /*
33  * This serves as a wrapper around c10::Stream and acts as a representation for
34  * a SYCL queue, which allows asynchronous execution of XPU tasks.
35  */
36 class C10_XPU_API XPUStream {
37  public:
38   enum Unchecked { UNCHECKED };
39 
40   /// Construct a XPUStream from a Stream. This construction is checked, and
41   /// will raise an error if the Stream is not, in fact, a XPU stream.
XPUStream(Stream stream)42   explicit XPUStream(Stream stream) : stream_(stream) {
43     TORCH_CHECK(stream_.device_type() == DeviceType::XPU);
44   }
45 
46   /// Construct a XPUStream from a Stream with no error checking.
XPUStream(Unchecked,Stream stream)47   explicit XPUStream(Unchecked, Stream stream) : stream_(stream) {}
48 
49   bool operator==(const XPUStream& other) const noexcept {
50     return unwrap() == other.unwrap();
51   }
52 
53   bool operator!=(const XPUStream& other) const noexcept {
54     return unwrap() != other.unwrap();
55   }
56 
57   /// Implicit conversion to sycl::queue&.
58   operator sycl::queue&() const {
59     return queue();
60   }
61 
62   /// Implicit conversion to Stream (a.k.a., forget that the stream is a
63   /// XPU stream).
Stream()64   operator Stream() const {
65     return unwrap();
66   }
67 
68   /// Get the XPU device type that this stream is associated with.
device_type()69   DeviceType device_type() const {
70     return DeviceType::XPU;
71   }
72 
73   /// Get the XPU device index that this stream is associated with.
device_index()74   DeviceIndex device_index() const {
75     return stream_.device_index();
76   }
77 
78   /// Get the full Device that this stream is associated with. The Device is
79   /// guaranteed to be a XPU device.
device()80   Device device() const {
81     return Device(DeviceType::XPU, device_index());
82   }
83 
84   /// Return the stream ID corresponding to this particular stream. StreamId is
85   /// a int64_t representation generated by its type and index.
id()86   StreamId id() const {
87     return stream_.id();
88   }
89 
90   /// Return true if all enqueued tasks in this stream have been completed,
91   /// otherwise return false.
query()92   bool query() const {
93     return queue().ext_oneapi_empty();
94   }
95 
96   /// Performs a blocking wait for the completion of all enqueued tasks in this
97   /// stream.
synchronize()98   void synchronize() const {
99     queue().wait_and_throw();
100     const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
101     if (C10_UNLIKELY(interp)) {
102       (*interp)->trace_gpu_stream_synchronization(
103           c10::kXPU, reinterpret_cast<uintptr_t>(&queue()));
104     }
105   }
106 
107   /// Return the priority that this stream is associated with. Lower numbers
108   /// represent higher priority.
109   int priority() const;
110 
111   /// Explicit conversion to sycl::queue&.
112   sycl::queue& queue() const;
113 
114   /// Explicit conversion to Stream.
unwrap()115   Stream unwrap() const {
116     return stream_;
117   }
118 
119   /// Reversibly pack a XPUStream into a struct representation. The XPUStream
120   /// can be unpacked using unpack3().
pack3()121   struct c10::StreamData3 pack3() const {
122     return stream_.pack3();
123   }
124 
125   /// Unpack a XPUStream from the 3 fields generated by pack3().
unpack3(StreamId stream_id,DeviceIndex device_index,DeviceType device_type)126   static XPUStream unpack3(
127       StreamId stream_id,
128       DeviceIndex device_index,
129       DeviceType device_type) {
130     return XPUStream(Stream::unpack3(stream_id, device_index, device_type));
131   }
132 
133   /// Return the range of priority **supported by PyTorch**.
priority_range()134   static std::tuple<int, int> priority_range() {
135     return std::make_tuple(0, -max_compile_time_stream_priorities + 1);
136   }
137 
138  private:
139   Stream stream_;
140 };
141 
142 /**
143  * Get a stream from the pool in a round-robin fashion.
144  *
145  * You can request a stream from the highest priority pool by setting
146  * isHighPriority to true for a specific device.
147  */
148 C10_XPU_API XPUStream
149 getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1);
150 
151 /**
152  * Get a stream from the pool in a round-robin fashion.
153  *
154  * You can request a stream by setting a priority value for a specific device.
155  * The priority number lower, the priority higher.
156  */
157 C10_XPU_API XPUStream
158 getStreamFromPool(const int priority, DeviceIndex device = -1);
159 
160 /**
161  * Get the current XPU stream, for the passed XPU device, or for the current
162  * device if no device index is passed.
163  */
164 C10_XPU_API XPUStream getCurrentXPUStream(DeviceIndex device = -1);
165 
166 /**
167  * Set the current stream on the device of the passed in stream to be the passed
168  * in stream.
169  */
170 C10_XPU_API void setCurrentXPUStream(XPUStream stream);
171 
172 C10_XPU_API std::ostream& operator<<(std::ostream& stream, const XPUStream& s);
173 
174 /**
175  * Block all reserved SYCL queues in the stream pools on the device, and wait
176  * for their synchronizations.
177  */
178 C10_XPU_API void syncStreamsOnDevice(DeviceIndex device = -1);
179 
180 } // namespace c10::xpu
181 
182 namespace std {
183 template <>
184 struct hash<c10::xpu::XPUStream> {
185   size_t operator()(c10::xpu::XPUStream s) const noexcept {
186     return std::hash<c10::Stream>{}(s.unwrap());
187   }
188 };
189 } // namespace std
190