1 #pragma once 2 3 #include <c10/core/Stream.h> 4 #include <c10/core/impl/GPUTrace.h> 5 #include <c10/xpu/XPUFunctions.h> 6 7 namespace c10::xpu { 8 9 /* 10 * Note [Stream Management] 11 * 12 * An XPUStream is an abstraction of an actual SYCL queue in which SYCL kernel 13 * can execute. Currently, there are several pools per device to manage SYCL 14 * queue, and a device's pool is lazily created. 15 * 16 * There are two pools per device. The first pool contains "normal priority" 17 * queues. The second pool is the "high priority" queues. There are 32 queues in 18 * per pool per device, and when a queue is requested one of these queues is 19 * returned round-robin. That is, the first queue requested is at index 0, the 20 * second at index 1... to index 31, then index 0 again. 21 * 22 * This means that if 33 queues are requested, the first and last queues 23 * requested are actually the same queue (under the covers) and kernels enqueued 24 * on them cannot run concurrently. 25 * 26 * It is safe to enqueue a kernel on the same queue from two different 27 * threads as the SYCL specification described. 28 */ 29 30 static constexpr int max_compile_time_stream_priorities = 2; 31 32 /* 33 * This serves as a wrapper around c10::Stream and acts as a representation for 34 * a SYCL queue, which allows asynchronous execution of XPU tasks. 35 */ 36 class C10_XPU_API XPUStream { 37 public: 38 enum Unchecked { UNCHECKED }; 39 40 /// Construct a XPUStream from a Stream. This construction is checked, and 41 /// will raise an error if the Stream is not, in fact, a XPU stream. XPUStream(Stream stream)42 explicit XPUStream(Stream stream) : stream_(stream) { 43 TORCH_CHECK(stream_.device_type() == DeviceType::XPU); 44 } 45 46 /// Construct a XPUStream from a Stream with no error checking. XPUStream(Unchecked,Stream stream)47 explicit XPUStream(Unchecked, Stream stream) : stream_(stream) {} 48 49 bool operator==(const XPUStream& other) const noexcept { 50 return unwrap() == other.unwrap(); 51 } 52 53 bool operator!=(const XPUStream& other) const noexcept { 54 return unwrap() != other.unwrap(); 55 } 56 57 /// Implicit conversion to sycl::queue&. 58 operator sycl::queue&() const { 59 return queue(); 60 } 61 62 /// Implicit conversion to Stream (a.k.a., forget that the stream is a 63 /// XPU stream). Stream()64 operator Stream() const { 65 return unwrap(); 66 } 67 68 /// Get the XPU device type that this stream is associated with. device_type()69 DeviceType device_type() const { 70 return DeviceType::XPU; 71 } 72 73 /// Get the XPU device index that this stream is associated with. device_index()74 DeviceIndex device_index() const { 75 return stream_.device_index(); 76 } 77 78 /// Get the full Device that this stream is associated with. The Device is 79 /// guaranteed to be a XPU device. device()80 Device device() const { 81 return Device(DeviceType::XPU, device_index()); 82 } 83 84 /// Return the stream ID corresponding to this particular stream. StreamId is 85 /// a int64_t representation generated by its type and index. id()86 StreamId id() const { 87 return stream_.id(); 88 } 89 90 /// Return true if all enqueued tasks in this stream have been completed, 91 /// otherwise return false. query()92 bool query() const { 93 return queue().ext_oneapi_empty(); 94 } 95 96 /// Performs a blocking wait for the completion of all enqueued tasks in this 97 /// stream. synchronize()98 void synchronize() const { 99 queue().wait_and_throw(); 100 const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); 101 if (C10_UNLIKELY(interp)) { 102 (*interp)->trace_gpu_stream_synchronization( 103 c10::kXPU, reinterpret_cast<uintptr_t>(&queue())); 104 } 105 } 106 107 /// Return the priority that this stream is associated with. Lower numbers 108 /// represent higher priority. 109 int priority() const; 110 111 /// Explicit conversion to sycl::queue&. 112 sycl::queue& queue() const; 113 114 /// Explicit conversion to Stream. unwrap()115 Stream unwrap() const { 116 return stream_; 117 } 118 119 /// Reversibly pack a XPUStream into a struct representation. The XPUStream 120 /// can be unpacked using unpack3(). pack3()121 struct c10::StreamData3 pack3() const { 122 return stream_.pack3(); 123 } 124 125 /// Unpack a XPUStream from the 3 fields generated by pack3(). unpack3(StreamId stream_id,DeviceIndex device_index,DeviceType device_type)126 static XPUStream unpack3( 127 StreamId stream_id, 128 DeviceIndex device_index, 129 DeviceType device_type) { 130 return XPUStream(Stream::unpack3(stream_id, device_index, device_type)); 131 } 132 133 /// Return the range of priority **supported by PyTorch**. priority_range()134 static std::tuple<int, int> priority_range() { 135 return std::make_tuple(0, -max_compile_time_stream_priorities + 1); 136 } 137 138 private: 139 Stream stream_; 140 }; 141 142 /** 143 * Get a stream from the pool in a round-robin fashion. 144 * 145 * You can request a stream from the highest priority pool by setting 146 * isHighPriority to true for a specific device. 147 */ 148 C10_XPU_API XPUStream 149 getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1); 150 151 /** 152 * Get a stream from the pool in a round-robin fashion. 153 * 154 * You can request a stream by setting a priority value for a specific device. 155 * The priority number lower, the priority higher. 156 */ 157 C10_XPU_API XPUStream 158 getStreamFromPool(const int priority, DeviceIndex device = -1); 159 160 /** 161 * Get the current XPU stream, for the passed XPU device, or for the current 162 * device if no device index is passed. 163 */ 164 C10_XPU_API XPUStream getCurrentXPUStream(DeviceIndex device = -1); 165 166 /** 167 * Set the current stream on the device of the passed in stream to be the passed 168 * in stream. 169 */ 170 C10_XPU_API void setCurrentXPUStream(XPUStream stream); 171 172 C10_XPU_API std::ostream& operator<<(std::ostream& stream, const XPUStream& s); 173 174 /** 175 * Block all reserved SYCL queues in the stream pools on the device, and wait 176 * for their synchronizations. 177 */ 178 C10_XPU_API void syncStreamsOnDevice(DeviceIndex device = -1); 179 180 } // namespace c10::xpu 181 182 namespace std { 183 template <> 184 struct hash<c10::xpu::XPUStream> { 185 size_t operator()(c10::xpu::XPUStream s) const noexcept { 186 return std::hash<c10::Stream>{}(s.unwrap()); 187 } 188 }; 189 } // namespace std 190