xref: /aosp_15_r20/external/pytorch/c10/xpu/XPUStream.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <c10/core/Stream.h>
4*da0073e9SAndroid Build Coastguard Worker #include <c10/core/impl/GPUTrace.h>
5*da0073e9SAndroid Build Coastguard Worker #include <c10/xpu/XPUFunctions.h>
6*da0073e9SAndroid Build Coastguard Worker 
7*da0073e9SAndroid Build Coastguard Worker namespace c10::xpu {
8*da0073e9SAndroid Build Coastguard Worker 
9*da0073e9SAndroid Build Coastguard Worker /*
10*da0073e9SAndroid Build Coastguard Worker  * Note [Stream Management]
11*da0073e9SAndroid Build Coastguard Worker  *
 * An XPUStream is an abstraction of an actual SYCL queue in which SYCL
 * kernels can execute. Currently, there are several pools per device to
 * manage SYCL queues, and a device's pools are lazily created.
15*da0073e9SAndroid Build Coastguard Worker  *
 * There are two pools per device. The first pool contains "normal priority"
 * queues. The second pool contains "high priority" queues. There are 32 queues
 * per pool per device, and when a queue is requested one of these queues is
 * returned round-robin. That is, the first queue requested is at index 0, the
 * second at index 1... to index 31, then index 0 again.
21*da0073e9SAndroid Build Coastguard Worker  *
22*da0073e9SAndroid Build Coastguard Worker  * This means that if 33 queues are requested, the first and last queues
23*da0073e9SAndroid Build Coastguard Worker  * requested are actually the same queue (under the covers) and kernels enqueued
24*da0073e9SAndroid Build Coastguard Worker  * on them cannot run concurrently.
25*da0073e9SAndroid Build Coastguard Worker  *
 * It is safe to enqueue a kernel on the same queue from two different
 * threads, as described by the SYCL specification.
28*da0073e9SAndroid Build Coastguard Worker  */
29*da0073e9SAndroid Build Coastguard Worker 
/// Number of distinct stream priority levels fixed at compile time. It is used
/// by XPUStream::priority_range() to compute the supported priority range.
/// Declared `inline` so that every translation unit including this header
/// shares one definition instead of getting its own internal-linkage copy.
inline constexpr int max_compile_time_stream_priorities = 2;
31*da0073e9SAndroid Build Coastguard Worker 
32*da0073e9SAndroid Build Coastguard Worker /*
33*da0073e9SAndroid Build Coastguard Worker  * This serves as a wrapper around c10::Stream and acts as a representation for
34*da0073e9SAndroid Build Coastguard Worker  * a SYCL queue, which allows asynchronous execution of XPU tasks.
35*da0073e9SAndroid Build Coastguard Worker  */
36*da0073e9SAndroid Build Coastguard Worker class C10_XPU_API XPUStream {
37*da0073e9SAndroid Build Coastguard Worker  public:
38*da0073e9SAndroid Build Coastguard Worker   enum Unchecked { UNCHECKED };
39*da0073e9SAndroid Build Coastguard Worker 
40*da0073e9SAndroid Build Coastguard Worker   /// Construct a XPUStream from a Stream. This construction is checked, and
41*da0073e9SAndroid Build Coastguard Worker   /// will raise an error if the Stream is not, in fact, a XPU stream.
XPUStream(Stream stream)42*da0073e9SAndroid Build Coastguard Worker   explicit XPUStream(Stream stream) : stream_(stream) {
43*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(stream_.device_type() == DeviceType::XPU);
44*da0073e9SAndroid Build Coastguard Worker   }
45*da0073e9SAndroid Build Coastguard Worker 
46*da0073e9SAndroid Build Coastguard Worker   /// Construct a XPUStream from a Stream with no error checking.
XPUStream(Unchecked,Stream stream)47*da0073e9SAndroid Build Coastguard Worker   explicit XPUStream(Unchecked, Stream stream) : stream_(stream) {}
48*da0073e9SAndroid Build Coastguard Worker 
49*da0073e9SAndroid Build Coastguard Worker   bool operator==(const XPUStream& other) const noexcept {
50*da0073e9SAndroid Build Coastguard Worker     return unwrap() == other.unwrap();
51*da0073e9SAndroid Build Coastguard Worker   }
52*da0073e9SAndroid Build Coastguard Worker 
53*da0073e9SAndroid Build Coastguard Worker   bool operator!=(const XPUStream& other) const noexcept {
54*da0073e9SAndroid Build Coastguard Worker     return unwrap() != other.unwrap();
55*da0073e9SAndroid Build Coastguard Worker   }
56*da0073e9SAndroid Build Coastguard Worker 
57*da0073e9SAndroid Build Coastguard Worker   /// Implicit conversion to sycl::queue&.
58*da0073e9SAndroid Build Coastguard Worker   operator sycl::queue&() const {
59*da0073e9SAndroid Build Coastguard Worker     return queue();
60*da0073e9SAndroid Build Coastguard Worker   }
61*da0073e9SAndroid Build Coastguard Worker 
62*da0073e9SAndroid Build Coastguard Worker   /// Implicit conversion to Stream (a.k.a., forget that the stream is a
63*da0073e9SAndroid Build Coastguard Worker   /// XPU stream).
Stream()64*da0073e9SAndroid Build Coastguard Worker   operator Stream() const {
65*da0073e9SAndroid Build Coastguard Worker     return unwrap();
66*da0073e9SAndroid Build Coastguard Worker   }
67*da0073e9SAndroid Build Coastguard Worker 
68*da0073e9SAndroid Build Coastguard Worker   /// Get the XPU device type that this stream is associated with.
device_type()69*da0073e9SAndroid Build Coastguard Worker   DeviceType device_type() const {
70*da0073e9SAndroid Build Coastguard Worker     return DeviceType::XPU;
71*da0073e9SAndroid Build Coastguard Worker   }
72*da0073e9SAndroid Build Coastguard Worker 
73*da0073e9SAndroid Build Coastguard Worker   /// Get the XPU device index that this stream is associated with.
device_index()74*da0073e9SAndroid Build Coastguard Worker   DeviceIndex device_index() const {
75*da0073e9SAndroid Build Coastguard Worker     return stream_.device_index();
76*da0073e9SAndroid Build Coastguard Worker   }
77*da0073e9SAndroid Build Coastguard Worker 
78*da0073e9SAndroid Build Coastguard Worker   /// Get the full Device that this stream is associated with. The Device is
79*da0073e9SAndroid Build Coastguard Worker   /// guaranteed to be a XPU device.
device()80*da0073e9SAndroid Build Coastguard Worker   Device device() const {
81*da0073e9SAndroid Build Coastguard Worker     return Device(DeviceType::XPU, device_index());
82*da0073e9SAndroid Build Coastguard Worker   }
83*da0073e9SAndroid Build Coastguard Worker 
84*da0073e9SAndroid Build Coastguard Worker   /// Return the stream ID corresponding to this particular stream. StreamId is
85*da0073e9SAndroid Build Coastguard Worker   /// a int64_t representation generated by its type and index.
id()86*da0073e9SAndroid Build Coastguard Worker   StreamId id() const {
87*da0073e9SAndroid Build Coastguard Worker     return stream_.id();
88*da0073e9SAndroid Build Coastguard Worker   }
89*da0073e9SAndroid Build Coastguard Worker 
90*da0073e9SAndroid Build Coastguard Worker   /// Return true if all enqueued tasks in this stream have been completed,
91*da0073e9SAndroid Build Coastguard Worker   /// otherwise return false.
query()92*da0073e9SAndroid Build Coastguard Worker   bool query() const {
93*da0073e9SAndroid Build Coastguard Worker     return queue().ext_oneapi_empty();
94*da0073e9SAndroid Build Coastguard Worker   }
95*da0073e9SAndroid Build Coastguard Worker 
96*da0073e9SAndroid Build Coastguard Worker   /// Performs a blocking wait for the completion of all enqueued tasks in this
97*da0073e9SAndroid Build Coastguard Worker   /// stream.
synchronize()98*da0073e9SAndroid Build Coastguard Worker   void synchronize() const {
99*da0073e9SAndroid Build Coastguard Worker     queue().wait_and_throw();
100*da0073e9SAndroid Build Coastguard Worker     const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
101*da0073e9SAndroid Build Coastguard Worker     if (C10_UNLIKELY(interp)) {
102*da0073e9SAndroid Build Coastguard Worker       (*interp)->trace_gpu_stream_synchronization(
103*da0073e9SAndroid Build Coastguard Worker           c10::kXPU, reinterpret_cast<uintptr_t>(&queue()));
104*da0073e9SAndroid Build Coastguard Worker     }
105*da0073e9SAndroid Build Coastguard Worker   }
106*da0073e9SAndroid Build Coastguard Worker 
107*da0073e9SAndroid Build Coastguard Worker   /// Return the priority that this stream is associated with. Lower numbers
108*da0073e9SAndroid Build Coastguard Worker   /// represent higher priority.
109*da0073e9SAndroid Build Coastguard Worker   int priority() const;
110*da0073e9SAndroid Build Coastguard Worker 
111*da0073e9SAndroid Build Coastguard Worker   /// Explicit conversion to sycl::queue&.
112*da0073e9SAndroid Build Coastguard Worker   sycl::queue& queue() const;
113*da0073e9SAndroid Build Coastguard Worker 
114*da0073e9SAndroid Build Coastguard Worker   /// Explicit conversion to Stream.
unwrap()115*da0073e9SAndroid Build Coastguard Worker   Stream unwrap() const {
116*da0073e9SAndroid Build Coastguard Worker     return stream_;
117*da0073e9SAndroid Build Coastguard Worker   }
118*da0073e9SAndroid Build Coastguard Worker 
119*da0073e9SAndroid Build Coastguard Worker   /// Reversibly pack a XPUStream into a struct representation. The XPUStream
120*da0073e9SAndroid Build Coastguard Worker   /// can be unpacked using unpack3().
pack3()121*da0073e9SAndroid Build Coastguard Worker   struct c10::StreamData3 pack3() const {
122*da0073e9SAndroid Build Coastguard Worker     return stream_.pack3();
123*da0073e9SAndroid Build Coastguard Worker   }
124*da0073e9SAndroid Build Coastguard Worker 
125*da0073e9SAndroid Build Coastguard Worker   /// Unpack a XPUStream from the 3 fields generated by pack3().
unpack3(StreamId stream_id,DeviceIndex device_index,DeviceType device_type)126*da0073e9SAndroid Build Coastguard Worker   static XPUStream unpack3(
127*da0073e9SAndroid Build Coastguard Worker       StreamId stream_id,
128*da0073e9SAndroid Build Coastguard Worker       DeviceIndex device_index,
129*da0073e9SAndroid Build Coastguard Worker       DeviceType device_type) {
130*da0073e9SAndroid Build Coastguard Worker     return XPUStream(Stream::unpack3(stream_id, device_index, device_type));
131*da0073e9SAndroid Build Coastguard Worker   }
132*da0073e9SAndroid Build Coastguard Worker 
133*da0073e9SAndroid Build Coastguard Worker   /// Return the range of priority **supported by PyTorch**.
priority_range()134*da0073e9SAndroid Build Coastguard Worker   static std::tuple<int, int> priority_range() {
135*da0073e9SAndroid Build Coastguard Worker     return std::make_tuple(0, -max_compile_time_stream_priorities + 1);
136*da0073e9SAndroid Build Coastguard Worker   }
137*da0073e9SAndroid Build Coastguard Worker 
138*da0073e9SAndroid Build Coastguard Worker  private:
139*da0073e9SAndroid Build Coastguard Worker   Stream stream_;
140*da0073e9SAndroid Build Coastguard Worker };
141*da0073e9SAndroid Build Coastguard Worker 
142*da0073e9SAndroid Build Coastguard Worker /**
143*da0073e9SAndroid Build Coastguard Worker  * Get a stream from the pool in a round-robin fashion.
144*da0073e9SAndroid Build Coastguard Worker  *
145*da0073e9SAndroid Build Coastguard Worker  * You can request a stream from the highest priority pool by setting
146*da0073e9SAndroid Build Coastguard Worker  * isHighPriority to true for a specific device.
147*da0073e9SAndroid Build Coastguard Worker  */
148*da0073e9SAndroid Build Coastguard Worker C10_XPU_API XPUStream
149*da0073e9SAndroid Build Coastguard Worker getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1);
150*da0073e9SAndroid Build Coastguard Worker 
151*da0073e9SAndroid Build Coastguard Worker /**
152*da0073e9SAndroid Build Coastguard Worker  * Get a stream from the pool in a round-robin fashion.
153*da0073e9SAndroid Build Coastguard Worker  *
154*da0073e9SAndroid Build Coastguard Worker  * You can request a stream by setting a priority value for a specific device.
155*da0073e9SAndroid Build Coastguard Worker  * The priority number lower, the priority higher.
156*da0073e9SAndroid Build Coastguard Worker  */
157*da0073e9SAndroid Build Coastguard Worker C10_XPU_API XPUStream
158*da0073e9SAndroid Build Coastguard Worker getStreamFromPool(const int priority, DeviceIndex device = -1);
159*da0073e9SAndroid Build Coastguard Worker 
160*da0073e9SAndroid Build Coastguard Worker /**
161*da0073e9SAndroid Build Coastguard Worker  * Get the current XPU stream, for the passed XPU device, or for the current
162*da0073e9SAndroid Build Coastguard Worker  * device if no device index is passed.
163*da0073e9SAndroid Build Coastguard Worker  */
164*da0073e9SAndroid Build Coastguard Worker C10_XPU_API XPUStream getCurrentXPUStream(DeviceIndex device = -1);
165*da0073e9SAndroid Build Coastguard Worker 
166*da0073e9SAndroid Build Coastguard Worker /**
167*da0073e9SAndroid Build Coastguard Worker  * Set the current stream on the device of the passed in stream to be the passed
168*da0073e9SAndroid Build Coastguard Worker  * in stream.
169*da0073e9SAndroid Build Coastguard Worker  */
170*da0073e9SAndroid Build Coastguard Worker C10_XPU_API void setCurrentXPUStream(XPUStream stream);
171*da0073e9SAndroid Build Coastguard Worker 
172*da0073e9SAndroid Build Coastguard Worker C10_XPU_API std::ostream& operator<<(std::ostream& stream, const XPUStream& s);
173*da0073e9SAndroid Build Coastguard Worker 
174*da0073e9SAndroid Build Coastguard Worker /**
175*da0073e9SAndroid Build Coastguard Worker  * Block all reserved SYCL queues in the stream pools on the device, and wait
176*da0073e9SAndroid Build Coastguard Worker  * for their synchronizations.
177*da0073e9SAndroid Build Coastguard Worker  */
178*da0073e9SAndroid Build Coastguard Worker C10_XPU_API void syncStreamsOnDevice(DeviceIndex device = -1);
179*da0073e9SAndroid Build Coastguard Worker 
180*da0073e9SAndroid Build Coastguard Worker } // namespace c10::xpu
181*da0073e9SAndroid Build Coastguard Worker 
182*da0073e9SAndroid Build Coastguard Worker namespace std {
183*da0073e9SAndroid Build Coastguard Worker template <>
184*da0073e9SAndroid Build Coastguard Worker struct hash<c10::xpu::XPUStream> {
185*da0073e9SAndroid Build Coastguard Worker   size_t operator()(c10::xpu::XPUStream s) const noexcept {
186*da0073e9SAndroid Build Coastguard Worker     return std::hash<c10::Stream>{}(s.unwrap());
187*da0073e9SAndroid Build Coastguard Worker   }
188*da0073e9SAndroid Build Coastguard Worker };
189*da0073e9SAndroid Build Coastguard Worker } // namespace std
190