1 #pragma once 2 3 #include <c10/core/Device.h> 4 #include <c10/core/DeviceType.h> 5 #include <c10/macros/Export.h> 6 #include <c10/util/Exception.h> 7 #include <cstddef> 8 #include <cstdint> 9 #include <functional> 10 #include <ostream> 11 12 namespace c10 { 13 14 /// An index representing a specific stream. A StreamId is not independently 15 /// meaningful without knowing the Device it is associated with; try to 16 /// use Stream rather than StreamId directly. 17 /// 18 /// StreamIds are opaque; they are assigned by some DeviceType-specific 19 /// numbering system which is not visible to the user. HOWEVER, we 20 /// guarantee that StreamId 0 is always a valid stream, and corresponds 21 /// to some sort of "default" stream. 22 using StreamId = int64_t; 23 24 struct C10_API StreamData3 { 25 StreamId stream_id; 26 DeviceIndex device_index; 27 DeviceType device_type; 28 }; 29 30 // NB: I decided not to call the above StreamIndex to avoid confusion with 31 // DeviceIndex. This way, you access device index with index(), and stream id 32 // with id() 33 34 /** 35 * A stream is a software mechanism used to synchronize launched kernels 36 * without requiring explicit synchronizations between kernels. The basic 37 * model is that every kernel launch is associated with a stream: every 38 * kernel on the same stream is implicitly synchronized so that if I launch 39 * kernels A and B on the same stream, A is guaranteed to finish before B 40 * launches. If I want B to run concurrently with A, I must schedule 41 * it on a different stream. 42 * 43 * The Stream class is a backend agnostic value class representing a stream 44 * which I may schedule a kernel on. Every stream is associated with a device, 45 * which is recorded in stream, which is used to avoid confusion about which 46 * device a stream refers to. 47 * 48 * Streams are explicitly thread-safe, in the sense that it is OK to pass 49 * a Stream from one thread to another, and kernels queued from two different 50 * threads will still get serialized appropriately. (Of course, the 51 * time when the kernels get queued is undetermined unless you synchronize 52 * host side ;) 53 * 54 * Stream does NOT have a default constructor. Streams are for expert 55 * users; if you want to use Streams, we're going to assume you know 56 * how to deal with C++ template error messages if you try to 57 * resize() a vector of Streams. 58 * 59 * Known instances of streams in backends: 60 * 61 * - cudaStream_t (CUDA) 62 * - hipStream_t (HIP) 63 * - cl_command_queue (OpenCL) (NB: Caffe2's existing OpenCL integration 64 * does NOT support command queues.) 65 * 66 * Because this class is device agnostic, it cannot provide backend-specific 67 * functionality (e.g., get the cudaStream_t of a CUDA stream.) There are 68 * wrapper classes which provide this functionality, e.g., CUDAStream. 69 */ 70 class C10_API Stream final { 71 private: 72 Device device_; 73 StreamId id_; 74 75 public: 76 enum Unsafe { UNSAFE }; 77 enum Default { DEFAULT }; 78 79 /// Unsafely construct a stream from a Device and a StreamId. In 80 /// general, only specific implementations of streams for a 81 /// backend should manufacture Stream directly in this way; other users 82 /// should use the provided APIs to get a stream. In particular, 83 /// we don't require backends to give any guarantees about non-zero 84 /// StreamIds; they are welcome to allocate in whatever way they like. Stream(Unsafe,Device device,StreamId id)85 explicit Stream(Unsafe, Device device, StreamId id) 86 : device_(device), id_(id) {} 87 88 /// Construct the default stream of a Device. The default stream is 89 /// NOT the same as the current stream; default stream is a fixed stream 90 /// that never changes, whereas the current stream may be changed by 91 /// StreamGuard. Stream(Default,Device device)92 explicit Stream(Default, Device device) : device_(device), id_(0) {} 93 94 bool operator==(const Stream& other) const noexcept { 95 return this->device_ == other.device_ && this->id_ == other.id_; 96 } 97 bool operator!=(const Stream& other) const noexcept { 98 return !(*this == other); 99 } 100 device()101 Device device() const noexcept { 102 return device_; 103 } device_type()104 DeviceType device_type() const noexcept { 105 return device_.type(); 106 } device_index()107 DeviceIndex device_index() const noexcept { 108 return device_.index(); 109 } id()110 StreamId id() const noexcept { 111 return id_; 112 } 113 114 // Enqueues a wait instruction in the stream's work queue. 115 // This instruction is a no-op unless the event is marked 116 // for recording. In that case the stream stops processing 117 // until the event is recorded. 118 template <typename T> wait(const T & event)119 void wait(const T& event) const { 120 event.block(*this); 121 } 122 123 // Return whether all asynchronous work previously enqueued on this stream 124 // has completed running on the device. 125 bool query() const; 126 127 // Wait (by blocking the calling thread) until all asynchronous work enqueued 128 // on this stream has completed running on the device. 129 void synchronize() const; 130 131 // The purpose of this function is to more conveniently permit binding 132 // of Stream to and from Python. Without packing, I have to setup a whole 133 // class with two fields (device and stream id); with packing I can just 134 // store a single uint64_t. 135 // 136 // The particular way we pack streams into a uint64_t is considered an 137 // implementation detail and should not be relied upon. hash()138 uint64_t hash() const noexcept { 139 // Concat these together into a 64-bit integer 140 uint64_t bits = static_cast<uint64_t>(device_type()) << 56 | 141 static_cast<uint64_t>(device_index()) << 48 | 142 // Remove the sign extension part of the 64-bit address because 143 // the id might be used to hold a pointer. 144 (static_cast<uint64_t>(id()) & ((1ull << 48) - 1)); 145 return bits; 146 } 147 pack3()148 struct StreamData3 pack3() const { 149 return {id(), device_index(), device_type()}; 150 } 151 unpack3(StreamId stream_id,DeviceIndex device_index,DeviceType device_type)152 static Stream unpack3( 153 StreamId stream_id, 154 DeviceIndex device_index, 155 DeviceType device_type) { 156 TORCH_CHECK(isValidDeviceType(device_type)); 157 return Stream(UNSAFE, Device(device_type, device_index), stream_id); 158 } 159 160 // I decided NOT to provide setters on this class, because really, 161 // why would you change the device of a stream? Just construct 162 // it correctly from the beginning dude. 163 }; 164 165 C10_API std::ostream& operator<<(std::ostream& stream, const Stream& s); 166 167 } // namespace c10 168 169 namespace std { 170 template <> 171 struct hash<c10::Stream> { 172 size_t operator()(c10::Stream s) const noexcept { 173 return std::hash<uint64_t>{}(s.hash()); 174 } 175 }; 176 } // namespace std 177