1*da0073e9SAndroid Build Coastguard Worker #pragma once 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Worker #include <c10/core/Stream.h> 4*da0073e9SAndroid Build Coastguard Worker #include <c10/core/impl/GPUTrace.h> 5*da0073e9SAndroid Build Coastguard Worker #include <c10/xpu/XPUFunctions.h> 6*da0073e9SAndroid Build Coastguard Worker 7*da0073e9SAndroid Build Coastguard Worker namespace c10::xpu { 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Worker /* 10*da0073e9SAndroid Build Coastguard Worker * Note [Stream Management] 11*da0073e9SAndroid Build Coastguard Worker * 12*da0073e9SAndroid Build Coastguard Worker * An XPUStream is an abstraction of an actual SYCL queue in which SYCL kernel 13*da0073e9SAndroid Build Coastguard Worker * can execute. Currently, there are several pools per device to manage SYCL 14*da0073e9SAndroid Build Coastguard Worker * queue, and a device's pool is lazily created. 15*da0073e9SAndroid Build Coastguard Worker * 16*da0073e9SAndroid Build Coastguard Worker * There are two pools per device. The first pool contains "normal priority" 17*da0073e9SAndroid Build Coastguard Worker * queues. The second pool is the "high priority" queues. There are 32 queues in 18*da0073e9SAndroid Build Coastguard Worker * per pool per device, and when a queue is requested one of these queues is 19*da0073e9SAndroid Build Coastguard Worker * returned round-robin. That is, the first queue requested is at index 0, the 20*da0073e9SAndroid Build Coastguard Worker * second at index 1... to index 31, then index 0 again. 21*da0073e9SAndroid Build Coastguard Worker * 22*da0073e9SAndroid Build Coastguard Worker * This means that if 33 queues are requested, the first and last queues 23*da0073e9SAndroid Build Coastguard Worker * requested are actually the same queue (under the covers) and kernels enqueued 24*da0073e9SAndroid Build Coastguard Worker * on them cannot run concurrently. 25*da0073e9SAndroid Build Coastguard Worker * 26*da0073e9SAndroid Build Coastguard Worker * It is safe to enqueue a kernel on the same queue from two different 27*da0073e9SAndroid Build Coastguard Worker * threads as the SYCL specification described. 28*da0073e9SAndroid Build Coastguard Worker */ 29*da0073e9SAndroid Build Coastguard Worker 30*da0073e9SAndroid Build Coastguard Worker static constexpr int max_compile_time_stream_priorities = 2; 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Worker /* 33*da0073e9SAndroid Build Coastguard Worker * This serves as a wrapper around c10::Stream and acts as a representation for 34*da0073e9SAndroid Build Coastguard Worker * a SYCL queue, which allows asynchronous execution of XPU tasks. 35*da0073e9SAndroid Build Coastguard Worker */ 36*da0073e9SAndroid Build Coastguard Worker class C10_XPU_API XPUStream { 37*da0073e9SAndroid Build Coastguard Worker public: 38*da0073e9SAndroid Build Coastguard Worker enum Unchecked { UNCHECKED }; 39*da0073e9SAndroid Build Coastguard Worker 40*da0073e9SAndroid Build Coastguard Worker /// Construct a XPUStream from a Stream. This construction is checked, and 41*da0073e9SAndroid Build Coastguard Worker /// will raise an error if the Stream is not, in fact, a XPU stream. XPUStream(Stream stream)42*da0073e9SAndroid Build Coastguard Worker explicit XPUStream(Stream stream) : stream_(stream) { 43*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(stream_.device_type() == DeviceType::XPU); 44*da0073e9SAndroid Build Coastguard Worker } 45*da0073e9SAndroid Build Coastguard Worker 46*da0073e9SAndroid Build Coastguard Worker /// Construct a XPUStream from a Stream with no error checking. XPUStream(Unchecked,Stream stream)47*da0073e9SAndroid Build Coastguard Worker explicit XPUStream(Unchecked, Stream stream) : stream_(stream) {} 48*da0073e9SAndroid Build Coastguard Worker 49*da0073e9SAndroid Build Coastguard Worker bool operator==(const XPUStream& other) const noexcept { 50*da0073e9SAndroid Build Coastguard Worker return unwrap() == other.unwrap(); 51*da0073e9SAndroid Build Coastguard Worker } 52*da0073e9SAndroid Build Coastguard Worker 53*da0073e9SAndroid Build Coastguard Worker bool operator!=(const XPUStream& other) const noexcept { 54*da0073e9SAndroid Build Coastguard Worker return unwrap() != other.unwrap(); 55*da0073e9SAndroid Build Coastguard Worker } 56*da0073e9SAndroid Build Coastguard Worker 57*da0073e9SAndroid Build Coastguard Worker /// Implicit conversion to sycl::queue&. 58*da0073e9SAndroid Build Coastguard Worker operator sycl::queue&() const { 59*da0073e9SAndroid Build Coastguard Worker return queue(); 60*da0073e9SAndroid Build Coastguard Worker } 61*da0073e9SAndroid Build Coastguard Worker 62*da0073e9SAndroid Build Coastguard Worker /// Implicit conversion to Stream (a.k.a., forget that the stream is a 63*da0073e9SAndroid Build Coastguard Worker /// XPU stream). Stream()64*da0073e9SAndroid Build Coastguard Worker operator Stream() const { 65*da0073e9SAndroid Build Coastguard Worker return unwrap(); 66*da0073e9SAndroid Build Coastguard Worker } 67*da0073e9SAndroid Build Coastguard Worker 68*da0073e9SAndroid Build Coastguard Worker /// Get the XPU device type that this stream is associated with. device_type()69*da0073e9SAndroid Build Coastguard Worker DeviceType device_type() const { 70*da0073e9SAndroid Build Coastguard Worker return DeviceType::XPU; 71*da0073e9SAndroid Build Coastguard Worker } 72*da0073e9SAndroid Build Coastguard Worker 73*da0073e9SAndroid Build Coastguard Worker /// Get the XPU device index that this stream is associated with. device_index()74*da0073e9SAndroid Build Coastguard Worker DeviceIndex device_index() const { 75*da0073e9SAndroid Build Coastguard Worker return stream_.device_index(); 76*da0073e9SAndroid Build Coastguard Worker } 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Worker /// Get the full Device that this stream is associated with. The Device is 79*da0073e9SAndroid Build Coastguard Worker /// guaranteed to be a XPU device. device()80*da0073e9SAndroid Build Coastguard Worker Device device() const { 81*da0073e9SAndroid Build Coastguard Worker return Device(DeviceType::XPU, device_index()); 82*da0073e9SAndroid Build Coastguard Worker } 83*da0073e9SAndroid Build Coastguard Worker 84*da0073e9SAndroid Build Coastguard Worker /// Return the stream ID corresponding to this particular stream. StreamId is 85*da0073e9SAndroid Build Coastguard Worker /// a int64_t representation generated by its type and index. id()86*da0073e9SAndroid Build Coastguard Worker StreamId id() const { 87*da0073e9SAndroid Build Coastguard Worker return stream_.id(); 88*da0073e9SAndroid Build Coastguard Worker } 89*da0073e9SAndroid Build Coastguard Worker 90*da0073e9SAndroid Build Coastguard Worker /// Return true if all enqueued tasks in this stream have been completed, 91*da0073e9SAndroid Build Coastguard Worker /// otherwise return false. query()92*da0073e9SAndroid Build Coastguard Worker bool query() const { 93*da0073e9SAndroid Build Coastguard Worker return queue().ext_oneapi_empty(); 94*da0073e9SAndroid Build Coastguard Worker } 95*da0073e9SAndroid Build Coastguard Worker 96*da0073e9SAndroid Build Coastguard Worker /// Performs a blocking wait for the completion of all enqueued tasks in this 97*da0073e9SAndroid Build Coastguard Worker /// stream. synchronize()98*da0073e9SAndroid Build Coastguard Worker void synchronize() const { 99*da0073e9SAndroid Build Coastguard Worker queue().wait_and_throw(); 100*da0073e9SAndroid Build Coastguard Worker const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); 101*da0073e9SAndroid Build Coastguard Worker if (C10_UNLIKELY(interp)) { 102*da0073e9SAndroid Build Coastguard Worker (*interp)->trace_gpu_stream_synchronization( 103*da0073e9SAndroid Build Coastguard Worker c10::kXPU, reinterpret_cast<uintptr_t>(&queue())); 104*da0073e9SAndroid Build Coastguard Worker } 105*da0073e9SAndroid Build Coastguard Worker } 106*da0073e9SAndroid Build Coastguard Worker 107*da0073e9SAndroid Build Coastguard Worker /// Return the priority that this stream is associated with. Lower numbers 108*da0073e9SAndroid Build Coastguard Worker /// represent higher priority. 109*da0073e9SAndroid Build Coastguard Worker int priority() const; 110*da0073e9SAndroid Build Coastguard Worker 111*da0073e9SAndroid Build Coastguard Worker /// Explicit conversion to sycl::queue&. 112*da0073e9SAndroid Build Coastguard Worker sycl::queue& queue() const; 113*da0073e9SAndroid Build Coastguard Worker 114*da0073e9SAndroid Build Coastguard Worker /// Explicit conversion to Stream. unwrap()115*da0073e9SAndroid Build Coastguard Worker Stream unwrap() const { 116*da0073e9SAndroid Build Coastguard Worker return stream_; 117*da0073e9SAndroid Build Coastguard Worker } 118*da0073e9SAndroid Build Coastguard Worker 119*da0073e9SAndroid Build Coastguard Worker /// Reversibly pack a XPUStream into a struct representation. The XPUStream 120*da0073e9SAndroid Build Coastguard Worker /// can be unpacked using unpack3(). pack3()121*da0073e9SAndroid Build Coastguard Worker struct c10::StreamData3 pack3() const { 122*da0073e9SAndroid Build Coastguard Worker return stream_.pack3(); 123*da0073e9SAndroid Build Coastguard Worker } 124*da0073e9SAndroid Build Coastguard Worker 125*da0073e9SAndroid Build Coastguard Worker /// Unpack a XPUStream from the 3 fields generated by pack3(). unpack3(StreamId stream_id,DeviceIndex device_index,DeviceType device_type)126*da0073e9SAndroid Build Coastguard Worker static XPUStream unpack3( 127*da0073e9SAndroid Build Coastguard Worker StreamId stream_id, 128*da0073e9SAndroid Build Coastguard Worker DeviceIndex device_index, 129*da0073e9SAndroid Build Coastguard Worker DeviceType device_type) { 130*da0073e9SAndroid Build Coastguard Worker return XPUStream(Stream::unpack3(stream_id, device_index, device_type)); 131*da0073e9SAndroid Build Coastguard Worker } 132*da0073e9SAndroid Build Coastguard Worker 133*da0073e9SAndroid Build Coastguard Worker /// Return the range of priority **supported by PyTorch**. priority_range()134*da0073e9SAndroid Build Coastguard Worker static std::tuple<int, int> priority_range() { 135*da0073e9SAndroid Build Coastguard Worker return std::make_tuple(0, -max_compile_time_stream_priorities + 1); 136*da0073e9SAndroid Build Coastguard Worker } 137*da0073e9SAndroid Build Coastguard Worker 138*da0073e9SAndroid Build Coastguard Worker private: 139*da0073e9SAndroid Build Coastguard Worker Stream stream_; 140*da0073e9SAndroid Build Coastguard Worker }; 141*da0073e9SAndroid Build Coastguard Worker 142*da0073e9SAndroid Build Coastguard Worker /** 143*da0073e9SAndroid Build Coastguard Worker * Get a stream from the pool in a round-robin fashion. 144*da0073e9SAndroid Build Coastguard Worker * 145*da0073e9SAndroid Build Coastguard Worker * You can request a stream from the highest priority pool by setting 146*da0073e9SAndroid Build Coastguard Worker * isHighPriority to true for a specific device. 147*da0073e9SAndroid Build Coastguard Worker */ 148*da0073e9SAndroid Build Coastguard Worker C10_XPU_API XPUStream 149*da0073e9SAndroid Build Coastguard Worker getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1); 150*da0073e9SAndroid Build Coastguard Worker 151*da0073e9SAndroid Build Coastguard Worker /** 152*da0073e9SAndroid Build Coastguard Worker * Get a stream from the pool in a round-robin fashion. 153*da0073e9SAndroid Build Coastguard Worker * 154*da0073e9SAndroid Build Coastguard Worker * You can request a stream by setting a priority value for a specific device. 155*da0073e9SAndroid Build Coastguard Worker * The priority number lower, the priority higher. 156*da0073e9SAndroid Build Coastguard Worker */ 157*da0073e9SAndroid Build Coastguard Worker C10_XPU_API XPUStream 158*da0073e9SAndroid Build Coastguard Worker getStreamFromPool(const int priority, DeviceIndex device = -1); 159*da0073e9SAndroid Build Coastguard Worker 160*da0073e9SAndroid Build Coastguard Worker /** 161*da0073e9SAndroid Build Coastguard Worker * Get the current XPU stream, for the passed XPU device, or for the current 162*da0073e9SAndroid Build Coastguard Worker * device if no device index is passed. 163*da0073e9SAndroid Build Coastguard Worker */ 164*da0073e9SAndroid Build Coastguard Worker C10_XPU_API XPUStream getCurrentXPUStream(DeviceIndex device = -1); 165*da0073e9SAndroid Build Coastguard Worker 166*da0073e9SAndroid Build Coastguard Worker /** 167*da0073e9SAndroid Build Coastguard Worker * Set the current stream on the device of the passed in stream to be the passed 168*da0073e9SAndroid Build Coastguard Worker * in stream. 169*da0073e9SAndroid Build Coastguard Worker */ 170*da0073e9SAndroid Build Coastguard Worker C10_XPU_API void setCurrentXPUStream(XPUStream stream); 171*da0073e9SAndroid Build Coastguard Worker 172*da0073e9SAndroid Build Coastguard Worker C10_XPU_API std::ostream& operator<<(std::ostream& stream, const XPUStream& s); 173*da0073e9SAndroid Build Coastguard Worker 174*da0073e9SAndroid Build Coastguard Worker /** 175*da0073e9SAndroid Build Coastguard Worker * Block all reserved SYCL queues in the stream pools on the device, and wait 176*da0073e9SAndroid Build Coastguard Worker * for their synchronizations. 177*da0073e9SAndroid Build Coastguard Worker */ 178*da0073e9SAndroid Build Coastguard Worker C10_XPU_API void syncStreamsOnDevice(DeviceIndex device = -1); 179*da0073e9SAndroid Build Coastguard Worker 180*da0073e9SAndroid Build Coastguard Worker } // namespace c10::xpu 181*da0073e9SAndroid Build Coastguard Worker 182*da0073e9SAndroid Build Coastguard Worker namespace std { 183*da0073e9SAndroid Build Coastguard Worker template <> 184*da0073e9SAndroid Build Coastguard Worker struct hash<c10::xpu::XPUStream> { 185*da0073e9SAndroid Build Coastguard Worker size_t operator()(c10::xpu::XPUStream s) const noexcept { 186*da0073e9SAndroid Build Coastguard Worker return std::hash<c10::Stream>{}(s.unwrap()); 187*da0073e9SAndroid Build Coastguard Worker } 188*da0073e9SAndroid Build Coastguard Worker }; 189*da0073e9SAndroid Build Coastguard Worker } // namespace std 190