1*da0073e9SAndroid Build Coastguard Worker #pragma once 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Worker #include <cuda_runtime_api.h> 4*da0073e9SAndroid Build Coastguard Worker 5*da0073e9SAndroid Build Coastguard Worker #include <c10/core/DeviceGuard.h> 6*da0073e9SAndroid Build Coastguard Worker #include <c10/core/Stream.h> 7*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAFunctions.h> 8*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h> 9*da0073e9SAndroid Build Coastguard Worker 10*da0073e9SAndroid Build Coastguard Worker /* 11*da0073e9SAndroid Build Coastguard Worker * Stream pool note. 12*da0073e9SAndroid Build Coastguard Worker * 13*da0073e9SAndroid Build Coastguard Worker * A CUDAStream is an abstraction of an actual cuStream on the GPU. CUDAStreams 14*da0073e9SAndroid Build Coastguard Worker * are backed by cuStreams, but they use several pools to minimize the costs 15*da0073e9SAndroid Build Coastguard Worker * associated with creating, retaining, and destroying cuStreams. 16*da0073e9SAndroid Build Coastguard Worker * 17*da0073e9SAndroid Build Coastguard Worker * There are three pools per device, and a device's pools are lazily created. 18*da0073e9SAndroid Build Coastguard Worker * 19*da0073e9SAndroid Build Coastguard Worker * The first pool contains only the default stream. When the default stream 20*da0073e9SAndroid Build Coastguard Worker * is requested it's returned. 21*da0073e9SAndroid Build Coastguard Worker * 22*da0073e9SAndroid Build Coastguard Worker * The second pool is the "low priority" or "default priority" streams. In 23*da0073e9SAndroid Build Coastguard Worker * HIP builds there is no distinction between streams in this pool and streams 24*da0073e9SAndroid Build Coastguard Worker * in the third pool (below). There are 32 of these streams per device, and 25*da0073e9SAndroid Build Coastguard Worker * when a stream is requested one of these streams is returned round-robin. 26*da0073e9SAndroid Build Coastguard Worker * That is, the first stream requested is at index 0, the second at index 1... 27*da0073e9SAndroid Build Coastguard Worker * to index 31, then index 0 again. 28*da0073e9SAndroid Build Coastguard Worker * 29*da0073e9SAndroid Build Coastguard Worker * This means that if 33 low priority streams are requested, the first and 30*da0073e9SAndroid Build Coastguard Worker * last streams requested are actually the same stream (under the covers) 31*da0073e9SAndroid Build Coastguard Worker * and kernels enqueued on them cannot run concurrently. 32*da0073e9SAndroid Build Coastguard Worker * 33*da0073e9SAndroid Build Coastguard Worker * The third pool is the "high priority" streams. The third pool acts like 34*da0073e9SAndroid Build Coastguard Worker * the second pool except the streams are created with a higher priority. 35*da0073e9SAndroid Build Coastguard Worker * 36*da0073e9SAndroid Build Coastguard Worker * These pools suggest that stream users should prefer many short-lived streams, 37*da0073e9SAndroid Build Coastguard Worker * as the cost of acquiring and releasing streams is effectively zero. If 38*da0073e9SAndroid Build Coastguard Worker * many longer-lived streams are required in performance critical scenarios 39*da0073e9SAndroid Build Coastguard Worker * then the functionality here may need to be extended to allow, for example, 40*da0073e9SAndroid Build Coastguard Worker * "reserving" a subset of the pool so that other streams do not accidentally 41*da0073e9SAndroid Build Coastguard Worker * overlap the performance critical streams. 42*da0073e9SAndroid Build Coastguard Worker * 43*da0073e9SAndroid Build Coastguard Worker * Note: although the notion of "current stream for device" is thread local 44*da0073e9SAndroid Build Coastguard Worker * (every OS thread has a separate current stream, as one might expect), 45*da0073e9SAndroid Build Coastguard Worker * the stream pool is global across all threads; stream 0 is always stream 0 46*da0073e9SAndroid Build Coastguard Worker * no matter which thread you use it on. Multiple threads can synchronize 47*da0073e9SAndroid Build Coastguard Worker * on the same stream. Although the CUDA documentation is not very clear 48*da0073e9SAndroid Build Coastguard Worker * on the matter, streams are thread safe; e.g., it is safe to enqueue 49*da0073e9SAndroid Build Coastguard Worker * a kernel on the same stream from two different threads. 50*da0073e9SAndroid Build Coastguard Worker */ 51*da0073e9SAndroid Build Coastguard Worker 52*da0073e9SAndroid Build Coastguard Worker namespace c10::cuda { 53*da0073e9SAndroid Build Coastguard Worker 54*da0073e9SAndroid Build Coastguard Worker static constexpr int max_compile_time_stream_priorities = 4; 55*da0073e9SAndroid Build Coastguard Worker 56*da0073e9SAndroid Build Coastguard Worker // Value object representing a CUDA stream. This is just a wrapper 57*da0073e9SAndroid Build Coastguard Worker // around c10::Stream, but it comes with a little extra CUDA-specific 58*da0073e9SAndroid Build Coastguard Worker // functionality (conversion to cudaStream_t), and a guarantee that 59*da0073e9SAndroid Build Coastguard Worker // the wrapped c10::Stream really is a CUDA stream. 60*da0073e9SAndroid Build Coastguard Worker class C10_CUDA_API CUDAStream { 61*da0073e9SAndroid Build Coastguard Worker public: 62*da0073e9SAndroid Build Coastguard Worker enum Unchecked { UNCHECKED }; 63*da0073e9SAndroid Build Coastguard Worker 64*da0073e9SAndroid Build Coastguard Worker /// Construct a CUDAStream from a Stream. This construction is checked, 65*da0073e9SAndroid Build Coastguard Worker /// and will raise an error if the Stream is not, in fact, a CUDA stream. CUDAStream(Stream stream)66*da0073e9SAndroid Build Coastguard Worker explicit CUDAStream(Stream stream) : stream_(stream) { 67*da0073e9SAndroid Build Coastguard Worker TORCH_CHECK(stream_.device_type() == DeviceType::CUDA); 68*da0073e9SAndroid Build Coastguard Worker } 69*da0073e9SAndroid Build Coastguard Worker 70*da0073e9SAndroid Build Coastguard Worker /// Construct a CUDAStream from a Stream with no error checking. 71*da0073e9SAndroid Build Coastguard Worker /// This constructor uses the "named" constructor idiom, and can 72*da0073e9SAndroid Build Coastguard Worker /// be invoked as: CUDAStream(CUDAStream::UNCHECKED, stream) CUDAStream(Unchecked,Stream stream)73*da0073e9SAndroid Build Coastguard Worker explicit CUDAStream(Unchecked, Stream stream) : stream_(stream) {} 74*da0073e9SAndroid Build Coastguard Worker 75*da0073e9SAndroid Build Coastguard Worker bool operator==(const CUDAStream& other) const noexcept { 76*da0073e9SAndroid Build Coastguard Worker return unwrap() == other.unwrap(); 77*da0073e9SAndroid Build Coastguard Worker } 78*da0073e9SAndroid Build Coastguard Worker 79*da0073e9SAndroid Build Coastguard Worker bool operator!=(const CUDAStream& other) const noexcept { 80*da0073e9SAndroid Build Coastguard Worker return unwrap() != other.unwrap(); 81*da0073e9SAndroid Build Coastguard Worker } 82*da0073e9SAndroid Build Coastguard Worker 83*da0073e9SAndroid Build Coastguard Worker /// Implicit conversion to cudaStream_t. cudaStream_t()84*da0073e9SAndroid Build Coastguard Worker operator cudaStream_t() const { 85*da0073e9SAndroid Build Coastguard Worker return stream(); 86*da0073e9SAndroid Build Coastguard Worker } 87*da0073e9SAndroid Build Coastguard Worker 88*da0073e9SAndroid Build Coastguard Worker /// Implicit conversion to Stream (a.k.a., forget that the stream is a 89*da0073e9SAndroid Build Coastguard Worker /// CUDA stream). Stream()90*da0073e9SAndroid Build Coastguard Worker operator Stream() const { 91*da0073e9SAndroid Build Coastguard Worker return unwrap(); 92*da0073e9SAndroid Build Coastguard Worker } 93*da0073e9SAndroid Build Coastguard Worker 94*da0073e9SAndroid Build Coastguard Worker /// Used to avoid baking in device type explicitly to Python-side API. device_type()95*da0073e9SAndroid Build Coastguard Worker DeviceType device_type() const { 96*da0073e9SAndroid Build Coastguard Worker return DeviceType::CUDA; 97*da0073e9SAndroid Build Coastguard Worker } 98*da0073e9SAndroid Build Coastguard Worker 99*da0073e9SAndroid Build Coastguard Worker /// Get the CUDA device index that this stream is associated with. device_index()100*da0073e9SAndroid Build Coastguard Worker DeviceIndex device_index() const { 101*da0073e9SAndroid Build Coastguard Worker return stream_.device_index(); 102*da0073e9SAndroid Build Coastguard Worker } 103*da0073e9SAndroid Build Coastguard Worker 104*da0073e9SAndroid Build Coastguard Worker /// Get the full Device that this stream is associated with. The Device 105*da0073e9SAndroid Build Coastguard Worker /// is guaranteed to be a CUDA device. device()106*da0073e9SAndroid Build Coastguard Worker Device device() const { 107*da0073e9SAndroid Build Coastguard Worker return Device(DeviceType::CUDA, device_index()); 108*da0073e9SAndroid Build Coastguard Worker } 109*da0073e9SAndroid Build Coastguard Worker 110*da0073e9SAndroid Build Coastguard Worker /// Return the stream ID corresponding to this particular stream. id()111*da0073e9SAndroid Build Coastguard Worker StreamId id() const { 112*da0073e9SAndroid Build Coastguard Worker return stream_.id(); 113*da0073e9SAndroid Build Coastguard Worker } 114*da0073e9SAndroid Build Coastguard Worker query()115*da0073e9SAndroid Build Coastguard Worker bool query() const { 116*da0073e9SAndroid Build Coastguard Worker DeviceGuard guard{stream_.device()}; 117*da0073e9SAndroid Build Coastguard Worker cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaStreamQuery(stream())); 118*da0073e9SAndroid Build Coastguard Worker 119*da0073e9SAndroid Build Coastguard Worker if (err == cudaSuccess) { 120*da0073e9SAndroid Build Coastguard Worker return true; 121*da0073e9SAndroid Build Coastguard Worker } else if (err != cudaErrorNotReady) { 122*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(err); 123*da0073e9SAndroid Build Coastguard Worker } else { 124*da0073e9SAndroid Build Coastguard Worker // ignore and clear the error if not ready 125*da0073e9SAndroid Build Coastguard Worker (void)cudaGetLastError(); 126*da0073e9SAndroid Build Coastguard Worker } 127*da0073e9SAndroid Build Coastguard Worker 128*da0073e9SAndroid Build Coastguard Worker return false; 129*da0073e9SAndroid Build Coastguard Worker } 130*da0073e9SAndroid Build Coastguard Worker synchronize()131*da0073e9SAndroid Build Coastguard Worker void synchronize() const { 132*da0073e9SAndroid Build Coastguard Worker DeviceGuard guard{stream_.device()}; 133*da0073e9SAndroid Build Coastguard Worker c10::cuda::stream_synchronize(stream()); 134*da0073e9SAndroid Build Coastguard Worker } 135*da0073e9SAndroid Build Coastguard Worker priority()136*da0073e9SAndroid Build Coastguard Worker int priority() const { 137*da0073e9SAndroid Build Coastguard Worker DeviceGuard guard{stream_.device()}; 138*da0073e9SAndroid Build Coastguard Worker int priority = 0; 139*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK(cudaStreamGetPriority(stream(), &priority)); 140*da0073e9SAndroid Build Coastguard Worker return priority; 141*da0073e9SAndroid Build Coastguard Worker } 142*da0073e9SAndroid Build Coastguard Worker 143*da0073e9SAndroid Build Coastguard Worker /// Explicit conversion to cudaStream_t. 144*da0073e9SAndroid Build Coastguard Worker cudaStream_t stream() const; 145*da0073e9SAndroid Build Coastguard Worker 146*da0073e9SAndroid Build Coastguard Worker /// Explicit conversion to Stream. unwrap()147*da0073e9SAndroid Build Coastguard Worker Stream unwrap() const { 148*da0073e9SAndroid Build Coastguard Worker return stream_; 149*da0073e9SAndroid Build Coastguard Worker } 150*da0073e9SAndroid Build Coastguard Worker 151*da0073e9SAndroid Build Coastguard Worker /// Reversibly pack a CUDAStream into a struct representation. 152*da0073e9SAndroid Build Coastguard Worker /// Previously the stream's data was packed into a single int64_t, 153*da0073e9SAndroid Build Coastguard Worker /// as it was assumed the fields would not require more than 154*da0073e9SAndroid Build Coastguard Worker /// 64 bits of storage in total. 155*da0073e9SAndroid Build Coastguard Worker /// See https://github.com/pytorch/pytorch/issues/75854 156*da0073e9SAndroid Build Coastguard Worker /// for more information regarding newer platforms that may violate 157*da0073e9SAndroid Build Coastguard Worker /// this assumption. 158*da0073e9SAndroid Build Coastguard Worker /// 159*da0073e9SAndroid Build Coastguard Worker /// The CUDAStream can be unpacked using unpack(). pack3()160*da0073e9SAndroid Build Coastguard Worker struct c10::StreamData3 pack3() const { 161*da0073e9SAndroid Build Coastguard Worker return stream_.pack3(); 162*da0073e9SAndroid Build Coastguard Worker } 163*da0073e9SAndroid Build Coastguard Worker 164*da0073e9SAndroid Build Coastguard Worker // Unpack a CUDAStream from the 3 fields generated by pack(). unpack3(StreamId stream_id,DeviceIndex device_index,DeviceType device_type)165*da0073e9SAndroid Build Coastguard Worker static CUDAStream unpack3( 166*da0073e9SAndroid Build Coastguard Worker StreamId stream_id, 167*da0073e9SAndroid Build Coastguard Worker DeviceIndex device_index, 168*da0073e9SAndroid Build Coastguard Worker DeviceType device_type) { 169*da0073e9SAndroid Build Coastguard Worker return CUDAStream(Stream::unpack3(stream_id, device_index, device_type)); 170*da0073e9SAndroid Build Coastguard Worker } 171*da0073e9SAndroid Build Coastguard Worker priority_range()172*da0073e9SAndroid Build Coastguard Worker static std::tuple<int, int> priority_range() { 173*da0073e9SAndroid Build Coastguard Worker // Note: this returns the range of priority **supported by PyTorch**, not 174*da0073e9SAndroid Build Coastguard Worker // the range of priority **supported by CUDA**. The former is a subset of 175*da0073e9SAndroid Build Coastguard Worker // the latter. 176*da0073e9SAndroid Build Coastguard Worker int least_priority = 0, greatest_priority = 0; 177*da0073e9SAndroid Build Coastguard Worker C10_CUDA_CHECK( 178*da0073e9SAndroid Build Coastguard Worker cudaDeviceGetStreamPriorityRange(&least_priority, &greatest_priority)); 179*da0073e9SAndroid Build Coastguard Worker #ifdef USE_ROCM 180*da0073e9SAndroid Build Coastguard Worker // See Note [HIP stream priorities] 181*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT( 182*da0073e9SAndroid Build Coastguard Worker least_priority == 1, "Unexpected HIP stream priority range"); 183*da0073e9SAndroid Build Coastguard Worker least_priority = 0; 184*da0073e9SAndroid Build Coastguard Worker #else 185*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT( 186*da0073e9SAndroid Build Coastguard Worker least_priority == 0, "Unexpected CUDA stream priority range"); 187*da0073e9SAndroid Build Coastguard Worker #endif 188*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT( 189*da0073e9SAndroid Build Coastguard Worker greatest_priority <= -1, "Unexpected CUDA stream priority range"); 190*da0073e9SAndroid Build Coastguard Worker greatest_priority = std::max( 191*da0073e9SAndroid Build Coastguard Worker -c10::cuda::max_compile_time_stream_priorities + 1, greatest_priority); 192*da0073e9SAndroid Build Coastguard Worker return std::make_tuple(least_priority, greatest_priority); 193*da0073e9SAndroid Build Coastguard Worker } 194*da0073e9SAndroid Build Coastguard Worker 195*da0073e9SAndroid Build Coastguard Worker // Deleted for now; use CUDAEvent::block instead 196*da0073e9SAndroid Build Coastguard Worker // void synchronize_with(const CUDAEvent& event) const; 197*da0073e9SAndroid Build Coastguard Worker 198*da0073e9SAndroid Build Coastguard Worker private: 199*da0073e9SAndroid Build Coastguard Worker Stream stream_; 200*da0073e9SAndroid Build Coastguard Worker }; 201*da0073e9SAndroid Build Coastguard Worker 202*da0073e9SAndroid Build Coastguard Worker /** 203*da0073e9SAndroid Build Coastguard Worker * Get a new stream from the CUDA stream pool. You can think of this 204*da0073e9SAndroid Build Coastguard Worker * as "creating" a new stream, but no such creation actually happens; 205*da0073e9SAndroid Build Coastguard Worker * instead, streams are preallocated from the pool and returned in a 206*da0073e9SAndroid Build Coastguard Worker * round-robin fashion. 207*da0073e9SAndroid Build Coastguard Worker * 208*da0073e9SAndroid Build Coastguard Worker * You can request a stream from the high priority pool by setting 209*da0073e9SAndroid Build Coastguard Worker * isHighPriority to true, or a stream for a specific device by setting device 210*da0073e9SAndroid Build Coastguard Worker * (defaulting to the current CUDA stream.) 211*da0073e9SAndroid Build Coastguard Worker */ 212*da0073e9SAndroid Build Coastguard Worker C10_API CUDAStream 213*da0073e9SAndroid Build Coastguard Worker getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1); 214*da0073e9SAndroid Build Coastguard Worker // no default priority to disambiguate overloads 215*da0073e9SAndroid Build Coastguard Worker C10_API CUDAStream 216*da0073e9SAndroid Build Coastguard Worker getStreamFromPool(const int priority, DeviceIndex device = -1); 217*da0073e9SAndroid Build Coastguard Worker 218*da0073e9SAndroid Build Coastguard Worker /** 219*da0073e9SAndroid Build Coastguard Worker * Get a CUDAStream from a externally allocated one. 220*da0073e9SAndroid Build Coastguard Worker * 221*da0073e9SAndroid Build Coastguard Worker * This is mainly for interoperability with different libraries where we 222*da0073e9SAndroid Build Coastguard Worker * want to operate on a non-torch allocated stream for data exchange or similar 223*da0073e9SAndroid Build Coastguard Worker * purposes 224*da0073e9SAndroid Build Coastguard Worker */ 225*da0073e9SAndroid Build Coastguard Worker C10_API CUDAStream 226*da0073e9SAndroid Build Coastguard Worker getStreamFromExternal(cudaStream_t ext_stream, DeviceIndex device_index); 227*da0073e9SAndroid Build Coastguard Worker 228*da0073e9SAndroid Build Coastguard Worker /** 229*da0073e9SAndroid Build Coastguard Worker * Get the default CUDA stream, for the passed CUDA device, or for the 230*da0073e9SAndroid Build Coastguard Worker * current device if no device index is passed. The default stream is 231*da0073e9SAndroid Build Coastguard Worker * where most computation occurs when you aren't explicitly using 232*da0073e9SAndroid Build Coastguard Worker * streams. 233*da0073e9SAndroid Build Coastguard Worker */ 234*da0073e9SAndroid Build Coastguard Worker C10_API CUDAStream getDefaultCUDAStream(DeviceIndex device_index = -1); 235*da0073e9SAndroid Build Coastguard Worker 236*da0073e9SAndroid Build Coastguard Worker /** 237*da0073e9SAndroid Build Coastguard Worker * Get the current CUDA stream, for the passed CUDA device, or for the 238*da0073e9SAndroid Build Coastguard Worker * current device if no device index is passed. The current CUDA stream 239*da0073e9SAndroid Build Coastguard Worker * will usually be the default CUDA stream for the device, but it may 240*da0073e9SAndroid Build Coastguard Worker * be different if someone called 'setCurrentCUDAStream' or used 'StreamGuard' 241*da0073e9SAndroid Build Coastguard Worker * or 'CUDAStreamGuard'. 242*da0073e9SAndroid Build Coastguard Worker */ 243*da0073e9SAndroid Build Coastguard Worker C10_API CUDAStream getCurrentCUDAStream(DeviceIndex device_index = -1); 244*da0073e9SAndroid Build Coastguard Worker 245*da0073e9SAndroid Build Coastguard Worker /** 246*da0073e9SAndroid Build Coastguard Worker * Set the current stream on the device of the passed in stream to be 247*da0073e9SAndroid Build Coastguard Worker * the passed in stream. Yes, you read that right: this function 248*da0073e9SAndroid Build Coastguard Worker * has *nothing* to do with the current device: it toggles the current 249*da0073e9SAndroid Build Coastguard Worker * stream of the device of the passed stream. 250*da0073e9SAndroid Build Coastguard Worker * 251*da0073e9SAndroid Build Coastguard Worker * Confused? Avoid using this function; prefer using 'CUDAStreamGuard' instead 252*da0073e9SAndroid Build Coastguard Worker * (which will switch both your current device and current stream in the way you 253*da0073e9SAndroid Build Coastguard Worker * expect, and reset it back to its original state afterwards). 254*da0073e9SAndroid Build Coastguard Worker */ 255*da0073e9SAndroid Build Coastguard Worker C10_API void setCurrentCUDAStream(CUDAStream stream); 256*da0073e9SAndroid Build Coastguard Worker 257*da0073e9SAndroid Build Coastguard Worker C10_API std::ostream& operator<<(std::ostream& stream, const CUDAStream& s); 258*da0073e9SAndroid Build Coastguard Worker 259*da0073e9SAndroid Build Coastguard Worker } // namespace c10::cuda 260*da0073e9SAndroid Build Coastguard Worker 261*da0073e9SAndroid Build Coastguard Worker namespace std { 262*da0073e9SAndroid Build Coastguard Worker template <> 263*da0073e9SAndroid Build Coastguard Worker struct hash<c10::cuda::CUDAStream> { 264*da0073e9SAndroid Build Coastguard Worker size_t operator()(c10::cuda::CUDAStream s) const noexcept { 265*da0073e9SAndroid Build Coastguard Worker return std::hash<c10::Stream>{}(s.unwrap()); 266*da0073e9SAndroid Build Coastguard Worker } 267*da0073e9SAndroid Build Coastguard Worker }; 268*da0073e9SAndroid Build Coastguard Worker } // namespace std 269