xref: /aosp_15_r20/external/pytorch/c10/core/Stream.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/Device.h>
4 #include <c10/core/DeviceType.h>
5 #include <c10/macros/Export.h>
6 #include <c10/util/Exception.h>
7 #include <cstddef>
8 #include <cstdint>
9 #include <functional>
10 #include <ostream>
11 
12 namespace c10 {
13 
14 /// An index representing a specific stream.  A StreamId is not independently
15 /// meaningful without knowing the Device it is associated with; try to
16 /// use Stream rather than StreamId directly.
17 ///
18 /// StreamIds are opaque; they are assigned by some DeviceType-specific
19 /// numbering system which is not visible to the user.  HOWEVER, we
20 /// guarantee that StreamId 0 is always a valid stream, and corresponds
21 /// to some sort of "default" stream.
22 using StreamId = int64_t;
23 
24 struct C10_API StreamData3 {
25   StreamId stream_id;
26   DeviceIndex device_index;
27   DeviceType device_type;
28 };
29 
// NB: I decided not to call StreamId "StreamIndex", to avoid confusion with
// DeviceIndex.  This way, you access device index with index(), and stream id
// with id()
33 
34 /**
35  * A stream is a software mechanism used to synchronize launched kernels
36  * without requiring explicit synchronizations between kernels.  The basic
37  * model is that every kernel launch is associated with a stream: every
38  * kernel on the same stream is implicitly synchronized so that if I launch
39  * kernels A and B on the same stream, A is guaranteed to finish before B
40  * launches.  If I want B to run concurrently with A, I must schedule
41  * it on a different stream.
42  *
 * The Stream class is a backend-agnostic value class representing a stream
 * on which I may schedule a kernel.  Every stream is associated with a
 * device; that device is recorded in the Stream itself to avoid confusion
 * about which device a stream refers to.
47  *
48  * Streams are explicitly thread-safe, in the sense that it is OK to pass
49  * a Stream from one thread to another, and kernels queued from two different
50  * threads will still get serialized appropriately.  (Of course, the
51  * time when the kernels get queued is undetermined unless you synchronize
52  * host side ;)
53  *
54  * Stream does NOT have a default constructor.  Streams are for expert
55  * users; if you want to use Streams, we're going to assume you know
56  * how to deal with C++ template error messages if you try to
57  * resize() a vector of Streams.
58  *
59  * Known instances of streams in backends:
60  *
61  *  - cudaStream_t (CUDA)
62  *  - hipStream_t (HIP)
63  *  - cl_command_queue (OpenCL)  (NB: Caffe2's existing OpenCL integration
64  *    does NOT support command queues.)
65  *
66  * Because this class is device agnostic, it cannot provide backend-specific
67  * functionality (e.g., get the cudaStream_t of a CUDA stream.)  There are
68  * wrapper classes which provide this functionality, e.g., CUDAStream.
69  */
70 class C10_API Stream final {
71  private:
72   Device device_;
73   StreamId id_;
74 
75  public:
76   enum Unsafe { UNSAFE };
77   enum Default { DEFAULT };
78 
79   /// Unsafely construct a stream from a Device and a StreamId.  In
80   /// general, only specific implementations of streams for a
81   /// backend should manufacture Stream directly in this way; other users
82   /// should use the provided APIs to get a stream.  In particular,
83   /// we don't require backends to give any guarantees about non-zero
84   /// StreamIds; they are welcome to allocate in whatever way they like.
Stream(Unsafe,Device device,StreamId id)85   explicit Stream(Unsafe, Device device, StreamId id)
86       : device_(device), id_(id) {}
87 
88   /// Construct the default stream of a Device.  The default stream is
89   /// NOT the same as the current stream; default stream is a fixed stream
90   /// that never changes, whereas the current stream may be changed by
91   /// StreamGuard.
Stream(Default,Device device)92   explicit Stream(Default, Device device) : device_(device), id_(0) {}
93 
94   bool operator==(const Stream& other) const noexcept {
95     return this->device_ == other.device_ && this->id_ == other.id_;
96   }
97   bool operator!=(const Stream& other) const noexcept {
98     return !(*this == other);
99   }
100 
device()101   Device device() const noexcept {
102     return device_;
103   }
device_type()104   DeviceType device_type() const noexcept {
105     return device_.type();
106   }
device_index()107   DeviceIndex device_index() const noexcept {
108     return device_.index();
109   }
id()110   StreamId id() const noexcept {
111     return id_;
112   }
113 
114   // Enqueues a wait instruction in the stream's work queue.
115   // This instruction is a no-op unless the event is marked
116   // for recording. In that case the stream stops processing
117   // until the event is recorded.
118   template <typename T>
wait(const T & event)119   void wait(const T& event) const {
120     event.block(*this);
121   }
122 
123   // Return whether all asynchronous work previously enqueued on this stream
124   // has completed running on the device.
125   bool query() const;
126 
127   // Wait (by blocking the calling thread) until all asynchronous work enqueued
128   // on this stream has completed running on the device.
129   void synchronize() const;
130 
131   // The purpose of this function is to more conveniently permit binding
132   // of Stream to and from Python.  Without packing, I have to setup a whole
133   // class with two fields (device and stream id); with packing I can just
134   // store a single uint64_t.
135   //
136   // The particular way we pack streams into a uint64_t is considered an
137   // implementation detail and should not be relied upon.
hash()138   uint64_t hash() const noexcept {
139     // Concat these together into a 64-bit integer
140     uint64_t bits = static_cast<uint64_t>(device_type()) << 56 |
141         static_cast<uint64_t>(device_index()) << 48 |
142         // Remove the sign extension part of the 64-bit address because
143         // the id might be used to hold a pointer.
144         (static_cast<uint64_t>(id()) & ((1ull << 48) - 1));
145     return bits;
146   }
147 
pack3()148   struct StreamData3 pack3() const {
149     return {id(), device_index(), device_type()};
150   }
151 
unpack3(StreamId stream_id,DeviceIndex device_index,DeviceType device_type)152   static Stream unpack3(
153       StreamId stream_id,
154       DeviceIndex device_index,
155       DeviceType device_type) {
156     TORCH_CHECK(isValidDeviceType(device_type));
157     return Stream(UNSAFE, Device(device_type, device_index), stream_id);
158   }
159 
160   // I decided NOT to provide setters on this class, because really,
161   // why would you change the device of a stream?  Just construct
162   // it correctly from the beginning dude.
163 };
164 
165 C10_API std::ostream& operator<<(std::ostream& stream, const Stream& s);
166 
167 } // namespace c10
168 
169 namespace std {
170 template <>
171 struct hash<c10::Stream> {
172   size_t operator()(c10::Stream s) const noexcept {
173     return std::hash<uint64_t>{}(s.hash());
174   }
175 };
176 } // namespace std
177