1 #pragma once 2 3 #include <c10/core/DeviceType.h> 4 #include <c10/core/impl/InlineDeviceGuard.h> 5 #include <c10/core/impl/InlineStreamGuard.h> 6 #include <c10/cuda/CUDAMacros.h> 7 #include <c10/cuda/impl/CUDAGuardImpl.h> 8 9 namespace c10::cuda { 10 11 // This code is kind of boilerplatey. See Note [Whither the DeviceGuard 12 // boilerplate] 13 14 /// A variant of DeviceGuard that is specialized for CUDA. It accepts 15 /// integer indices (interpreting them as CUDA devices) and is a little 16 /// more efficient than DeviceGuard (it compiles to straight line 17 /// cudaSetDevice/cudaGetDevice calls); however, it can only be used 18 /// from code that links against CUDA directly. 19 struct CUDAGuard { 20 /// No default constructor; see Note [Omitted default constructor from RAII] 21 explicit CUDAGuard() = delete; 22 23 /// Set the current CUDA device to the passed device index. CUDAGuardCUDAGuard24 explicit CUDAGuard(DeviceIndex device_index) : guard_(device_index) {} 25 26 /// Sets the current CUDA device to the passed device. Errors if the passed 27 /// device is not a CUDA device. CUDAGuardCUDAGuard28 explicit CUDAGuard(Device device) : guard_(device) {} 29 30 // Copy is not allowed 31 CUDAGuard(const CUDAGuard&) = delete; 32 CUDAGuard& operator=(const CUDAGuard&) = delete; 33 34 // Move is not allowed (there is no uninitialized state) 35 CUDAGuard(CUDAGuard&& other) = delete; 36 CUDAGuard& operator=(CUDAGuard&& other) = delete; 37 38 /// Sets the CUDA device to the given device. Errors if the given device 39 /// is not a CUDA device. set_deviceCUDAGuard40 void set_device(Device device) { 41 guard_.set_device(device); 42 } 43 44 /// Sets the CUDA device to the given device. Errors if the given device 45 /// is not a CUDA device. (This method is provided for uniformity with 46 /// DeviceGuard). reset_deviceCUDAGuard47 void reset_device(Device device) { 48 guard_.reset_device(device); 49 } 50 51 /// Sets the CUDA device to the given device index. set_indexCUDAGuard52 void set_index(DeviceIndex device_index) { 53 guard_.set_index(device_index); 54 } 55 56 /// Returns the device that was set upon construction of the guard original_deviceCUDAGuard57 Device original_device() const { 58 return guard_.original_device(); 59 } 60 61 /// Returns the last device that was set via `set_device`, if any, otherwise 62 /// the device passed during construction. current_deviceCUDAGuard63 Device current_device() const { 64 return guard_.current_device(); 65 } 66 67 private: 68 /// The guard for the current device. 69 c10::impl::InlineDeviceGuard<impl::CUDAGuardImpl> guard_; 70 }; 71 72 /// A variant of OptionalDeviceGuard that is specialized for CUDA. See 73 /// CUDAGuard for when you can use this. 74 struct OptionalCUDAGuard { 75 /// Create an uninitialized OptionalCUDAGuard. OptionalCUDAGuardOptionalCUDAGuard76 explicit OptionalCUDAGuard() : guard_() {} 77 78 /// Set the current CUDA device to the passed Device, if it is not nullopt. OptionalCUDAGuardOptionalCUDAGuard79 explicit OptionalCUDAGuard(std::optional<Device> device_opt) 80 : guard_(device_opt) {} 81 82 /// Set the current CUDA device to the passed device index, if it is not 83 /// nullopt OptionalCUDAGuardOptionalCUDAGuard84 explicit OptionalCUDAGuard(std::optional<DeviceIndex> device_index_opt) 85 : guard_(device_index_opt) {} 86 87 // Copy is not allowed 88 OptionalCUDAGuard(const OptionalCUDAGuard&) = delete; 89 OptionalCUDAGuard& operator=(const OptionalCUDAGuard&) = delete; 90 91 // See Note [Move construction for RAII guards is tricky] 92 OptionalCUDAGuard(OptionalCUDAGuard&& other) = delete; 93 94 // See Note [Move assignment for RAII guards is tricky] 95 OptionalCUDAGuard& operator=(OptionalCUDAGuard&& other) = delete; 96 97 /// Sets the CUDA device to the given device, initializing the guard if it 98 /// is not already initialized. Errors if the given device is not a CUDA 99 /// device. set_deviceOptionalCUDAGuard100 void set_device(Device device) { 101 guard_.set_device(device); 102 } 103 104 /// Sets the CUDA device to the given device, initializing the guard if it is 105 /// not already initialized. Errors if the given device is not a CUDA device. 106 /// (This method is provided for uniformity with OptionalDeviceGuard). reset_deviceOptionalCUDAGuard107 void reset_device(Device device) { 108 guard_.reset_device(device); 109 } 110 111 /// Sets the CUDA device to the given device index, initializing the guard if 112 /// it is not already initialized. set_indexOptionalCUDAGuard113 void set_index(DeviceIndex device_index) { 114 guard_.set_index(device_index); 115 } 116 117 /// Returns the device that was set immediately prior to initialization of the 118 /// guard, or nullopt if the guard is uninitialized. original_deviceOptionalCUDAGuard119 std::optional<Device> original_device() const { 120 return guard_.original_device(); 121 } 122 123 /// Returns the most recent device that was set using this device guard, 124 /// either from construction, or via set_device, if the guard is initialized, 125 /// or nullopt if the guard is uninitialized. current_deviceOptionalCUDAGuard126 std::optional<Device> current_device() const { 127 return guard_.current_device(); 128 } 129 130 /// Restore the original CUDA device, resetting this guard to uninitialized 131 /// state. resetOptionalCUDAGuard132 void reset() { 133 guard_.reset(); 134 } 135 136 private: 137 c10::impl::InlineOptionalDeviceGuard<impl::CUDAGuardImpl> guard_; 138 }; 139 140 /// A variant of StreamGuard that is specialized for CUDA. See CUDAGuard 141 /// for when you can use this. 142 struct CUDAStreamGuard { 143 /// No default constructor, see Note [Omitted default constructor from RAII] 144 explicit CUDAStreamGuard() = delete; 145 146 /// Set the current CUDA device to the device associated with the passed 147 /// stream, and set the current CUDA stream on that device to the passed 148 /// stream. Errors if the Stream is not a CUDA stream. CUDAStreamGuardCUDAStreamGuard149 explicit CUDAStreamGuard(Stream stream) : guard_(stream) {} 150 151 /// Copy is disallowed 152 CUDAStreamGuard(const CUDAStreamGuard&) = delete; 153 CUDAStreamGuard& operator=(const CUDAStreamGuard&) = delete; 154 155 /// Move is disallowed, as CUDAStreamGuard does not have an uninitialized 156 /// state, which is required for moves on types with nontrivial destructors. 157 CUDAStreamGuard(CUDAStreamGuard&& other) = delete; 158 CUDAStreamGuard& operator=(CUDAStreamGuard&& other) = delete; 159 160 /// Resets the currently set stream to the original stream and 161 /// the currently set device to the original device. Then, 162 /// set the current device to the device associated with the passed stream, 163 /// and set the current stream on that device to the passed stream. 164 /// Errors if the stream passed is not a CUDA stream. 165 /// 166 /// NOTE: this implementation may skip some stream/device setting if 167 /// it can prove that it is unnecessary. 168 /// 169 /// WARNING: reset_stream does NOT preserve previously set streams on 170 /// different devices. If you need to set streams on multiple devices 171 /// on CUDA, use CUDAMultiStreamGuard instead. reset_streamCUDAStreamGuard172 void reset_stream(Stream stream) { 173 guard_.reset_stream(stream); 174 } 175 176 /// Returns the CUDA stream that was set at the time the guard was 177 /// constructed. original_streamCUDAStreamGuard178 CUDAStream original_stream() const { 179 return CUDAStream(CUDAStream::UNCHECKED, guard_.original_stream()); 180 } 181 182 /// Returns the most recent CUDA stream that was set using this device guard, 183 /// either from construction, or via set_stream. current_streamCUDAStreamGuard184 CUDAStream current_stream() const { 185 return CUDAStream(CUDAStream::UNCHECKED, guard_.current_stream()); 186 } 187 188 /// Returns the most recent CUDA device that was set using this device guard, 189 /// either from construction, or via set_device/reset_device/set_index. current_deviceCUDAStreamGuard190 Device current_device() const { 191 return guard_.current_device(); 192 } 193 194 /// Returns the CUDA device that was set at the most recent reset_stream(), 195 /// or otherwise the device at construction time. original_deviceCUDAStreamGuard196 Device original_device() const { 197 return guard_.original_device(); 198 } 199 200 private: 201 c10::impl::InlineStreamGuard<impl::CUDAGuardImpl> guard_; 202 }; 203 204 /// A variant of OptionalStreamGuard that is specialized for CUDA. See 205 /// CUDAGuard for when you can use this. 206 struct OptionalCUDAStreamGuard { 207 /// Create an uninitialized guard. OptionalCUDAStreamGuardOptionalCUDAStreamGuard208 explicit OptionalCUDAStreamGuard() : guard_() {} 209 210 /// Set the current CUDA device to the device associated with the passed 211 /// stream, and set the current CUDA stream on that device to the passed 212 /// stream. Errors if the Stream is not a CUDA stream. OptionalCUDAStreamGuardOptionalCUDAStreamGuard213 explicit OptionalCUDAStreamGuard(Stream stream) : guard_(stream) {} 214 215 /// Set the current device to the device associated with the passed stream, 216 /// and set the current stream on that device to the passed stream, 217 /// if the passed stream is not nullopt. OptionalCUDAStreamGuardOptionalCUDAStreamGuard218 explicit OptionalCUDAStreamGuard(std::optional<Stream> stream_opt) 219 : guard_(stream_opt) {} 220 221 /// Copy is disallowed 222 OptionalCUDAStreamGuard(const OptionalCUDAStreamGuard&) = delete; 223 OptionalCUDAStreamGuard& operator=(const OptionalCUDAStreamGuard&) = delete; 224 225 // See Note [Move construction for RAII guards is tricky] 226 OptionalCUDAStreamGuard(OptionalCUDAStreamGuard&& other) = delete; 227 228 // See Note [Move assignment for RAII guards is tricky] 229 OptionalCUDAStreamGuard& operator=(OptionalCUDAStreamGuard&& other) = delete; 230 231 /// Resets the currently set CUDA stream to the original stream and 232 /// the currently set device to the original device. Then, 233 /// set the current device to the device associated with the passed stream, 234 /// and set the current stream on that device to the passed stream. 235 /// Initializes the guard if it was not previously initialized. reset_streamOptionalCUDAStreamGuard236 void reset_stream(Stream stream) { 237 guard_.reset_stream(stream); 238 } 239 240 /// Returns the CUDA stream that was set at the time the guard was most 241 /// recently initialized, or nullopt if the guard is uninitialized. original_streamOptionalCUDAStreamGuard242 std::optional<CUDAStream> original_stream() const { 243 auto r = guard_.original_stream(); 244 if (r.has_value()) { 245 return std::make_optional(CUDAStream(CUDAStream::UNCHECKED, r.value())); 246 } else { 247 return std::nullopt; 248 } 249 } 250 251 /// Returns the most recent CUDA stream that was set using this stream guard, 252 /// either from construction, or via reset_stream, if the guard is 253 /// initialized, or nullopt if the guard is uninitialized. current_streamOptionalCUDAStreamGuard254 std::optional<CUDAStream> current_stream() const { 255 auto r = guard_.current_stream(); 256 if (r.has_value()) { 257 return std::make_optional(CUDAStream(CUDAStream::UNCHECKED, r.value())); 258 } else { 259 return std::nullopt; 260 } 261 } 262 263 /// Restore the original CUDA device and stream, resetting this guard to 264 /// uninitialized state. resetOptionalCUDAStreamGuard265 void reset() { 266 guard_.reset(); 267 } 268 269 private: 270 c10::impl::InlineOptionalStreamGuard<impl::CUDAGuardImpl> guard_; 271 }; 272 273 /// A variant of MultiStreamGuard that is specialized for CUDA. 274 struct CUDAMultiStreamGuard { CUDAMultiStreamGuardCUDAMultiStreamGuard275 explicit CUDAMultiStreamGuard(ArrayRef<CUDAStream> streams) 276 : guard_(unwrapStreams(streams)) {} 277 278 /// Copy is disallowed 279 CUDAMultiStreamGuard(const CUDAMultiStreamGuard&) = delete; 280 CUDAMultiStreamGuard& operator=(const CUDAMultiStreamGuard&) = delete; 281 282 // See Note [Move construction for RAII guards is tricky] 283 CUDAMultiStreamGuard(CUDAMultiStreamGuard&& other) = delete; 284 285 // See Note [Move assignment for RAII guards is tricky] 286 CUDAMultiStreamGuard& operator=(CUDAMultiStreamGuard&& other) = delete; 287 288 private: 289 c10::impl::InlineMultiStreamGuard<impl::CUDAGuardImpl> guard_; 290 unwrapStreamsCUDAMultiStreamGuard291 static std::vector<Stream> unwrapStreams(ArrayRef<CUDAStream> cudaStreams) { 292 std::vector<Stream> streams; 293 streams.reserve(cudaStreams.size()); 294 for (const CUDAStream& cudaStream : cudaStreams) { 295 streams.push_back(cudaStream); 296 } 297 return streams; 298 } 299 }; 300 301 } // namespace c10::cuda 302