xref: /aosp_15_r20/external/pytorch/c10/cuda/CUDAStream.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker 
3*da0073e9SAndroid Build Coastguard Worker #include <cuda_runtime_api.h>
4*da0073e9SAndroid Build Coastguard Worker 
5*da0073e9SAndroid Build Coastguard Worker #include <c10/core/DeviceGuard.h>
6*da0073e9SAndroid Build Coastguard Worker #include <c10/core/Stream.h>
7*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAFunctions.h>
8*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h>
9*da0073e9SAndroid Build Coastguard Worker 
/*
 * Stream pool note.
 *
 * A CUDAStream is an abstraction of an actual cuStream on the GPU. CUDAStreams
 * are backed by cuStreams, but they use several pools to minimize the costs
 * associated with creating, retaining, and destroying cuStreams.
 *
 * There are three pools per device, and a device's pools are lazily created.
 *
 * The first pool contains only the default stream. When the default stream
 * is requested, it is returned.
 *
 * The second pool holds the "low priority" (or "default priority") streams. In
 * HIP builds there is no distinction between streams in this pool and streams
 * in the third pool (below). There are 32 of these streams per device, and
 * when a stream is requested one of these streams is returned round-robin.
 * That is, the first stream requested is at index 0, the second at index 1,
 * and so on up to index 31, after which requests wrap around to index 0.
 *
 * This means that if 33 low priority streams are requested, the first and
 * last streams requested are actually the same stream (under the covers)
 * and kernels enqueued on them cannot run concurrently.
 *
 * The third pool holds the "high priority" streams. The third pool acts like
 * the second pool except the streams are created with a higher priority.
 *
 * These pools suggest that stream users should prefer many short-lived streams,
 * as the cost of acquiring and releasing streams is effectively zero. If
 * many longer-lived streams are required in performance critical scenarios
 * then the functionality here may need to be extended to allow, for example,
 * "reserving" a subset of the pool so that other streams do not accidentally
 * overlap the performance critical streams.
 *
 * Note: although the notion of "current stream for device" is thread local
 * (every OS thread has a separate current stream, as one might expect),
 * the stream pool is global across all threads; stream 0 is always stream 0
 * no matter which thread you use it on.  Multiple threads can synchronize
 * on the same stream.  Although the CUDA documentation is not very clear
 * on the matter, streams are thread safe; e.g., it is safe to enqueue
 * a kernel on the same stream from two different threads.
 */
51*da0073e9SAndroid Build Coastguard Worker 
52*da0073e9SAndroid Build Coastguard Worker namespace c10::cuda {
53*da0073e9SAndroid Build Coastguard Worker 
// Compile-time cap on the number of stream priority levels.
// CUDAStream::priority_range() uses it to clamp the greatest priority it
// reports to no lower than -max_compile_time_stream_priorities + 1.
static constexpr int max_compile_time_stream_priorities = 4;
55*da0073e9SAndroid Build Coastguard Worker 
56*da0073e9SAndroid Build Coastguard Worker // Value object representing a CUDA stream.  This is just a wrapper
57*da0073e9SAndroid Build Coastguard Worker // around c10::Stream, but it comes with a little extra CUDA-specific
58*da0073e9SAndroid Build Coastguard Worker // functionality (conversion to cudaStream_t), and a guarantee that
59*da0073e9SAndroid Build Coastguard Worker // the wrapped c10::Stream really is a CUDA stream.
60*da0073e9SAndroid Build Coastguard Worker class C10_CUDA_API CUDAStream {
61*da0073e9SAndroid Build Coastguard Worker  public:
62*da0073e9SAndroid Build Coastguard Worker   enum Unchecked { UNCHECKED };
63*da0073e9SAndroid Build Coastguard Worker 
64*da0073e9SAndroid Build Coastguard Worker   /// Construct a CUDAStream from a Stream.  This construction is checked,
65*da0073e9SAndroid Build Coastguard Worker   /// and will raise an error if the Stream is not, in fact, a CUDA stream.
CUDAStream(Stream stream)66*da0073e9SAndroid Build Coastguard Worker   explicit CUDAStream(Stream stream) : stream_(stream) {
67*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(stream_.device_type() == DeviceType::CUDA);
68*da0073e9SAndroid Build Coastguard Worker   }
69*da0073e9SAndroid Build Coastguard Worker 
70*da0073e9SAndroid Build Coastguard Worker   /// Construct a CUDAStream from a Stream with no error checking.
71*da0073e9SAndroid Build Coastguard Worker   /// This constructor uses the "named" constructor idiom, and can
72*da0073e9SAndroid Build Coastguard Worker   /// be invoked as: CUDAStream(CUDAStream::UNCHECKED, stream)
CUDAStream(Unchecked,Stream stream)73*da0073e9SAndroid Build Coastguard Worker   explicit CUDAStream(Unchecked, Stream stream) : stream_(stream) {}
74*da0073e9SAndroid Build Coastguard Worker 
75*da0073e9SAndroid Build Coastguard Worker   bool operator==(const CUDAStream& other) const noexcept {
76*da0073e9SAndroid Build Coastguard Worker     return unwrap() == other.unwrap();
77*da0073e9SAndroid Build Coastguard Worker   }
78*da0073e9SAndroid Build Coastguard Worker 
79*da0073e9SAndroid Build Coastguard Worker   bool operator!=(const CUDAStream& other) const noexcept {
80*da0073e9SAndroid Build Coastguard Worker     return unwrap() != other.unwrap();
81*da0073e9SAndroid Build Coastguard Worker   }
82*da0073e9SAndroid Build Coastguard Worker 
83*da0073e9SAndroid Build Coastguard Worker   /// Implicit conversion to cudaStream_t.
cudaStream_t()84*da0073e9SAndroid Build Coastguard Worker   operator cudaStream_t() const {
85*da0073e9SAndroid Build Coastguard Worker     return stream();
86*da0073e9SAndroid Build Coastguard Worker   }
87*da0073e9SAndroid Build Coastguard Worker 
88*da0073e9SAndroid Build Coastguard Worker   /// Implicit conversion to Stream (a.k.a., forget that the stream is a
89*da0073e9SAndroid Build Coastguard Worker   /// CUDA stream).
Stream()90*da0073e9SAndroid Build Coastguard Worker   operator Stream() const {
91*da0073e9SAndroid Build Coastguard Worker     return unwrap();
92*da0073e9SAndroid Build Coastguard Worker   }
93*da0073e9SAndroid Build Coastguard Worker 
94*da0073e9SAndroid Build Coastguard Worker   /// Used to avoid baking in device type explicitly to Python-side API.
device_type()95*da0073e9SAndroid Build Coastguard Worker   DeviceType device_type() const {
96*da0073e9SAndroid Build Coastguard Worker     return DeviceType::CUDA;
97*da0073e9SAndroid Build Coastguard Worker   }
98*da0073e9SAndroid Build Coastguard Worker 
99*da0073e9SAndroid Build Coastguard Worker   /// Get the CUDA device index that this stream is associated with.
device_index()100*da0073e9SAndroid Build Coastguard Worker   DeviceIndex device_index() const {
101*da0073e9SAndroid Build Coastguard Worker     return stream_.device_index();
102*da0073e9SAndroid Build Coastguard Worker   }
103*da0073e9SAndroid Build Coastguard Worker 
104*da0073e9SAndroid Build Coastguard Worker   /// Get the full Device that this stream is associated with.  The Device
105*da0073e9SAndroid Build Coastguard Worker   /// is guaranteed to be a CUDA device.
device()106*da0073e9SAndroid Build Coastguard Worker   Device device() const {
107*da0073e9SAndroid Build Coastguard Worker     return Device(DeviceType::CUDA, device_index());
108*da0073e9SAndroid Build Coastguard Worker   }
109*da0073e9SAndroid Build Coastguard Worker 
110*da0073e9SAndroid Build Coastguard Worker   /// Return the stream ID corresponding to this particular stream.
id()111*da0073e9SAndroid Build Coastguard Worker   StreamId id() const {
112*da0073e9SAndroid Build Coastguard Worker     return stream_.id();
113*da0073e9SAndroid Build Coastguard Worker   }
114*da0073e9SAndroid Build Coastguard Worker 
query()115*da0073e9SAndroid Build Coastguard Worker   bool query() const {
116*da0073e9SAndroid Build Coastguard Worker     DeviceGuard guard{stream_.device()};
117*da0073e9SAndroid Build Coastguard Worker     cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaStreamQuery(stream()));
118*da0073e9SAndroid Build Coastguard Worker 
119*da0073e9SAndroid Build Coastguard Worker     if (err == cudaSuccess) {
120*da0073e9SAndroid Build Coastguard Worker       return true;
121*da0073e9SAndroid Build Coastguard Worker     } else if (err != cudaErrorNotReady) {
122*da0073e9SAndroid Build Coastguard Worker       C10_CUDA_CHECK(err);
123*da0073e9SAndroid Build Coastguard Worker     } else {
124*da0073e9SAndroid Build Coastguard Worker       // ignore and clear the error if not ready
125*da0073e9SAndroid Build Coastguard Worker       (void)cudaGetLastError();
126*da0073e9SAndroid Build Coastguard Worker     }
127*da0073e9SAndroid Build Coastguard Worker 
128*da0073e9SAndroid Build Coastguard Worker     return false;
129*da0073e9SAndroid Build Coastguard Worker   }
130*da0073e9SAndroid Build Coastguard Worker 
synchronize()131*da0073e9SAndroid Build Coastguard Worker   void synchronize() const {
132*da0073e9SAndroid Build Coastguard Worker     DeviceGuard guard{stream_.device()};
133*da0073e9SAndroid Build Coastguard Worker     c10::cuda::stream_synchronize(stream());
134*da0073e9SAndroid Build Coastguard Worker   }
135*da0073e9SAndroid Build Coastguard Worker 
priority()136*da0073e9SAndroid Build Coastguard Worker   int priority() const {
137*da0073e9SAndroid Build Coastguard Worker     DeviceGuard guard{stream_.device()};
138*da0073e9SAndroid Build Coastguard Worker     int priority = 0;
139*da0073e9SAndroid Build Coastguard Worker     C10_CUDA_CHECK(cudaStreamGetPriority(stream(), &priority));
140*da0073e9SAndroid Build Coastguard Worker     return priority;
141*da0073e9SAndroid Build Coastguard Worker   }
142*da0073e9SAndroid Build Coastguard Worker 
143*da0073e9SAndroid Build Coastguard Worker   /// Explicit conversion to cudaStream_t.
144*da0073e9SAndroid Build Coastguard Worker   cudaStream_t stream() const;
145*da0073e9SAndroid Build Coastguard Worker 
146*da0073e9SAndroid Build Coastguard Worker   /// Explicit conversion to Stream.
unwrap()147*da0073e9SAndroid Build Coastguard Worker   Stream unwrap() const {
148*da0073e9SAndroid Build Coastguard Worker     return stream_;
149*da0073e9SAndroid Build Coastguard Worker   }
150*da0073e9SAndroid Build Coastguard Worker 
151*da0073e9SAndroid Build Coastguard Worker   /// Reversibly pack a CUDAStream into a struct representation.
152*da0073e9SAndroid Build Coastguard Worker   /// Previously the stream's data was packed into a single int64_t,
153*da0073e9SAndroid Build Coastguard Worker   /// as it was assumed the fields would not require more than
154*da0073e9SAndroid Build Coastguard Worker   /// 64 bits of storage in total.
155*da0073e9SAndroid Build Coastguard Worker   /// See https://github.com/pytorch/pytorch/issues/75854
156*da0073e9SAndroid Build Coastguard Worker   /// for more information regarding newer platforms that may violate
157*da0073e9SAndroid Build Coastguard Worker   /// this assumption.
158*da0073e9SAndroid Build Coastguard Worker   ///
159*da0073e9SAndroid Build Coastguard Worker   /// The CUDAStream can be unpacked using unpack().
pack3()160*da0073e9SAndroid Build Coastguard Worker   struct c10::StreamData3 pack3() const {
161*da0073e9SAndroid Build Coastguard Worker     return stream_.pack3();
162*da0073e9SAndroid Build Coastguard Worker   }
163*da0073e9SAndroid Build Coastguard Worker 
164*da0073e9SAndroid Build Coastguard Worker   // Unpack a CUDAStream from the 3 fields generated by pack().
unpack3(StreamId stream_id,DeviceIndex device_index,DeviceType device_type)165*da0073e9SAndroid Build Coastguard Worker   static CUDAStream unpack3(
166*da0073e9SAndroid Build Coastguard Worker       StreamId stream_id,
167*da0073e9SAndroid Build Coastguard Worker       DeviceIndex device_index,
168*da0073e9SAndroid Build Coastguard Worker       DeviceType device_type) {
169*da0073e9SAndroid Build Coastguard Worker     return CUDAStream(Stream::unpack3(stream_id, device_index, device_type));
170*da0073e9SAndroid Build Coastguard Worker   }
171*da0073e9SAndroid Build Coastguard Worker 
priority_range()172*da0073e9SAndroid Build Coastguard Worker   static std::tuple<int, int> priority_range() {
173*da0073e9SAndroid Build Coastguard Worker     // Note: this returns the range of priority **supported by PyTorch**, not
174*da0073e9SAndroid Build Coastguard Worker     // the range of priority **supported by CUDA**. The former is a subset of
175*da0073e9SAndroid Build Coastguard Worker     // the latter.
176*da0073e9SAndroid Build Coastguard Worker     int least_priority = 0, greatest_priority = 0;
177*da0073e9SAndroid Build Coastguard Worker     C10_CUDA_CHECK(
178*da0073e9SAndroid Build Coastguard Worker         cudaDeviceGetStreamPriorityRange(&least_priority, &greatest_priority));
179*da0073e9SAndroid Build Coastguard Worker #ifdef USE_ROCM
180*da0073e9SAndroid Build Coastguard Worker     // See Note [HIP stream priorities]
181*da0073e9SAndroid Build Coastguard Worker     TORCH_INTERNAL_ASSERT(
182*da0073e9SAndroid Build Coastguard Worker         least_priority == 1, "Unexpected HIP stream priority range");
183*da0073e9SAndroid Build Coastguard Worker     least_priority = 0;
184*da0073e9SAndroid Build Coastguard Worker #else
185*da0073e9SAndroid Build Coastguard Worker     TORCH_INTERNAL_ASSERT(
186*da0073e9SAndroid Build Coastguard Worker         least_priority == 0, "Unexpected CUDA stream priority range");
187*da0073e9SAndroid Build Coastguard Worker #endif
188*da0073e9SAndroid Build Coastguard Worker     TORCH_INTERNAL_ASSERT(
189*da0073e9SAndroid Build Coastguard Worker         greatest_priority <= -1, "Unexpected CUDA stream priority range");
190*da0073e9SAndroid Build Coastguard Worker     greatest_priority = std::max(
191*da0073e9SAndroid Build Coastguard Worker         -c10::cuda::max_compile_time_stream_priorities + 1, greatest_priority);
192*da0073e9SAndroid Build Coastguard Worker     return std::make_tuple(least_priority, greatest_priority);
193*da0073e9SAndroid Build Coastguard Worker   }
194*da0073e9SAndroid Build Coastguard Worker 
195*da0073e9SAndroid Build Coastguard Worker   // Deleted for now; use CUDAEvent::block instead
196*da0073e9SAndroid Build Coastguard Worker   // void synchronize_with(const CUDAEvent& event) const;
197*da0073e9SAndroid Build Coastguard Worker 
198*da0073e9SAndroid Build Coastguard Worker  private:
199*da0073e9SAndroid Build Coastguard Worker   Stream stream_;
200*da0073e9SAndroid Build Coastguard Worker };
201*da0073e9SAndroid Build Coastguard Worker 
/**
 * Get a stream from the CUDA stream pool.  You can think of this
 * as "creating" a new stream, but no such creation actually happens;
 * instead, streams are preallocated in the pool and returned in a
 * round-robin fashion.
 *
 * You can request a stream from the high priority pool by setting
 * isHighPriority to true, or a stream for a specific device by setting
 * device.  device defaults to -1; presumably this selects the current CUDA
 * device, as documented for getDefaultCUDAStream -- TODO confirm.
 */
C10_API CUDAStream
getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1);
// Overload that takes an integer priority instead of a flag.  It has
// no default priority, to disambiguate the overloads.
C10_API CUDAStream
getStreamFromPool(const int priority, DeviceIndex device = -1);
217*da0073e9SAndroid Build Coastguard Worker 
/**
 * Get a CUDAStream from an externally allocated stream.
 *
 * This is mainly for interoperability with other libraries, where we
 * want to operate on a stream that was not allocated by torch, for data
 * exchange or similar purposes.
 *
 * ext_stream is the externally created cudaStream_t to wrap; device_index is
 * presumably the index of the device that stream belongs to -- TODO confirm
 * against the implementation.
 */
C10_API CUDAStream
getStreamFromExternal(cudaStream_t ext_stream, DeviceIndex device_index);
227*da0073e9SAndroid Build Coastguard Worker 
/**
 * Get the default CUDA stream for the passed CUDA device, or for the
 * current device if no device index is passed (device_index defaults
 * to -1).  The default stream is where most computation occurs when you
 * aren't explicitly using streams.
 */
C10_API CUDAStream getDefaultCUDAStream(DeviceIndex device_index = -1);
235*da0073e9SAndroid Build Coastguard Worker 
/**
 * Get the current CUDA stream for the passed CUDA device, or for the
 * current device if no device index is passed (device_index defaults
 * to -1).  The current CUDA stream will usually be the default CUDA stream
 * for the device, but it may be different if someone called
 * 'setCurrentCUDAStream' or used 'StreamGuard' or 'CUDAStreamGuard'.
 */
C10_API CUDAStream getCurrentCUDAStream(DeviceIndex device_index = -1);
244*da0073e9SAndroid Build Coastguard Worker 
/**
 * Make the passed in stream the current stream on the device that the
 * stream belongs to.  Yes, you read that right: this function
 * has *nothing* to do with the current device: it changes the current
 * stream of the device of the passed stream.
 *
 * Confused?  Avoid using this function; prefer using 'CUDAStreamGuard' instead
 * (which will switch both your current device and current stream in the way you
 * expect, and reset them back to their original state afterwards).
 */
C10_API void setCurrentCUDAStream(CUDAStream stream);
256*da0073e9SAndroid Build Coastguard Worker 
// Stream-insertion operator for CUDAStream (defined out of line).  The exact
// text written for `s` is determined by the implementation, which is not
// visible in this header.
C10_API std::ostream& operator<<(std::ostream& stream, const CUDAStream& s);
258*da0073e9SAndroid Build Coastguard Worker 
259*da0073e9SAndroid Build Coastguard Worker } // namespace c10::cuda
260*da0073e9SAndroid Build Coastguard Worker 
261*da0073e9SAndroid Build Coastguard Worker namespace std {
262*da0073e9SAndroid Build Coastguard Worker template <>
263*da0073e9SAndroid Build Coastguard Worker struct hash<c10::cuda::CUDAStream> {
264*da0073e9SAndroid Build Coastguard Worker   size_t operator()(c10::cuda::CUDAStream s) const noexcept {
265*da0073e9SAndroid Build Coastguard Worker     return std::hash<c10::Stream>{}(s.unwrap());
266*da0073e9SAndroid Build Coastguard Worker   }
267*da0073e9SAndroid Build Coastguard Worker };
268*da0073e9SAndroid Build Coastguard Worker } // namespace std
269