1*da0073e9SAndroid Build Coastguard Worker #pragma once 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Worker #include <c10/core/DeviceType.h> 4*da0073e9SAndroid Build Coastguard Worker #include <c10/core/impl/InlineDeviceGuard.h> 5*da0073e9SAndroid Build Coastguard Worker #include <c10/core/impl/InlineStreamGuard.h> 6*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAMacros.h> 7*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/impl/CUDAGuardImpl.h> 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Worker namespace c10::cuda { 10*da0073e9SAndroid Build Coastguard Worker 11*da0073e9SAndroid Build Coastguard Worker // This code is kind of boilerplatey. See Note [Whither the DeviceGuard 12*da0073e9SAndroid Build Coastguard Worker // boilerplate] 13*da0073e9SAndroid Build Coastguard Worker 14*da0073e9SAndroid Build Coastguard Worker /// A variant of DeviceGuard that is specialized for CUDA. It accepts 15*da0073e9SAndroid Build Coastguard Worker /// integer indices (interpreting them as CUDA devices) and is a little 16*da0073e9SAndroid Build Coastguard Worker /// more efficient than DeviceGuard (it compiles to straight line 17*da0073e9SAndroid Build Coastguard Worker /// cudaSetDevice/cudaGetDevice calls); however, it can only be used 18*da0073e9SAndroid Build Coastguard Worker /// from code that links against CUDA directly. 19*da0073e9SAndroid Build Coastguard Worker struct CUDAGuard { 20*da0073e9SAndroid Build Coastguard Worker /// No default constructor; see Note [Omitted default constructor from RAII] 21*da0073e9SAndroid Build Coastguard Worker explicit CUDAGuard() = delete; 22*da0073e9SAndroid Build Coastguard Worker 23*da0073e9SAndroid Build Coastguard Worker /// Set the current CUDA device to the passed device index. CUDAGuardCUDAGuard24*da0073e9SAndroid Build Coastguard Worker explicit CUDAGuard(DeviceIndex device_index) : guard_(device_index) {} 25*da0073e9SAndroid Build Coastguard Worker 26*da0073e9SAndroid Build Coastguard Worker /// Sets the current CUDA device to the passed device. Errors if the passed 27*da0073e9SAndroid Build Coastguard Worker /// device is not a CUDA device. CUDAGuardCUDAGuard28*da0073e9SAndroid Build Coastguard Worker explicit CUDAGuard(Device device) : guard_(device) {} 29*da0073e9SAndroid Build Coastguard Worker 30*da0073e9SAndroid Build Coastguard Worker // Copy is not allowed 31*da0073e9SAndroid Build Coastguard Worker CUDAGuard(const CUDAGuard&) = delete; 32*da0073e9SAndroid Build Coastguard Worker CUDAGuard& operator=(const CUDAGuard&) = delete; 33*da0073e9SAndroid Build Coastguard Worker 34*da0073e9SAndroid Build Coastguard Worker // Move is not allowed (there is no uninitialized state) 35*da0073e9SAndroid Build Coastguard Worker CUDAGuard(CUDAGuard&& other) = delete; 36*da0073e9SAndroid Build Coastguard Worker CUDAGuard& operator=(CUDAGuard&& other) = delete; 37*da0073e9SAndroid Build Coastguard Worker 38*da0073e9SAndroid Build Coastguard Worker /// Sets the CUDA device to the given device. Errors if the given device 39*da0073e9SAndroid Build Coastguard Worker /// is not a CUDA device. set_deviceCUDAGuard40*da0073e9SAndroid Build Coastguard Worker void set_device(Device device) { 41*da0073e9SAndroid Build Coastguard Worker guard_.set_device(device); 42*da0073e9SAndroid Build Coastguard Worker } 43*da0073e9SAndroid Build Coastguard Worker 44*da0073e9SAndroid Build Coastguard Worker /// Sets the CUDA device to the given device. Errors if the given device 45*da0073e9SAndroid Build Coastguard Worker /// is not a CUDA device. (This method is provided for uniformity with 46*da0073e9SAndroid Build Coastguard Worker /// DeviceGuard). reset_deviceCUDAGuard47*da0073e9SAndroid Build Coastguard Worker void reset_device(Device device) { 48*da0073e9SAndroid Build Coastguard Worker guard_.reset_device(device); 49*da0073e9SAndroid Build Coastguard Worker } 50*da0073e9SAndroid Build Coastguard Worker 51*da0073e9SAndroid Build Coastguard Worker /// Sets the CUDA device to the given device index. set_indexCUDAGuard52*da0073e9SAndroid Build Coastguard Worker void set_index(DeviceIndex device_index) { 53*da0073e9SAndroid Build Coastguard Worker guard_.set_index(device_index); 54*da0073e9SAndroid Build Coastguard Worker } 55*da0073e9SAndroid Build Coastguard Worker 56*da0073e9SAndroid Build Coastguard Worker /// Returns the device that was set upon construction of the guard original_deviceCUDAGuard57*da0073e9SAndroid Build Coastguard Worker Device original_device() const { 58*da0073e9SAndroid Build Coastguard Worker return guard_.original_device(); 59*da0073e9SAndroid Build Coastguard Worker } 60*da0073e9SAndroid Build Coastguard Worker 61*da0073e9SAndroid Build Coastguard Worker /// Returns the last device that was set via `set_device`, if any, otherwise 62*da0073e9SAndroid Build Coastguard Worker /// the device passed during construction. current_deviceCUDAGuard63*da0073e9SAndroid Build Coastguard Worker Device current_device() const { 64*da0073e9SAndroid Build Coastguard Worker return guard_.current_device(); 65*da0073e9SAndroid Build Coastguard Worker } 66*da0073e9SAndroid Build Coastguard Worker 67*da0073e9SAndroid Build Coastguard Worker private: 68*da0073e9SAndroid Build Coastguard Worker /// The guard for the current device. 69*da0073e9SAndroid Build Coastguard Worker c10::impl::InlineDeviceGuard<impl::CUDAGuardImpl> guard_; 70*da0073e9SAndroid Build Coastguard Worker }; 71*da0073e9SAndroid Build Coastguard Worker 72*da0073e9SAndroid Build Coastguard Worker /// A variant of OptionalDeviceGuard that is specialized for CUDA. See 73*da0073e9SAndroid Build Coastguard Worker /// CUDAGuard for when you can use this. 74*da0073e9SAndroid Build Coastguard Worker struct OptionalCUDAGuard { 75*da0073e9SAndroid Build Coastguard Worker /// Create an uninitialized OptionalCUDAGuard. OptionalCUDAGuardOptionalCUDAGuard76*da0073e9SAndroid Build Coastguard Worker explicit OptionalCUDAGuard() : guard_() {} 77*da0073e9SAndroid Build Coastguard Worker 78*da0073e9SAndroid Build Coastguard Worker /// Set the current CUDA device to the passed Device, if it is not nullopt. OptionalCUDAGuardOptionalCUDAGuard79*da0073e9SAndroid Build Coastguard Worker explicit OptionalCUDAGuard(std::optional<Device> device_opt) 80*da0073e9SAndroid Build Coastguard Worker : guard_(device_opt) {} 81*da0073e9SAndroid Build Coastguard Worker 82*da0073e9SAndroid Build Coastguard Worker /// Set the current CUDA device to the passed device index, if it is not 83*da0073e9SAndroid Build Coastguard Worker /// nullopt OptionalCUDAGuardOptionalCUDAGuard84*da0073e9SAndroid Build Coastguard Worker explicit OptionalCUDAGuard(std::optional<DeviceIndex> device_index_opt) 85*da0073e9SAndroid Build Coastguard Worker : guard_(device_index_opt) {} 86*da0073e9SAndroid Build Coastguard Worker 87*da0073e9SAndroid Build Coastguard Worker // Copy is not allowed 88*da0073e9SAndroid Build Coastguard Worker OptionalCUDAGuard(const OptionalCUDAGuard&) = delete; 89*da0073e9SAndroid Build Coastguard Worker OptionalCUDAGuard& operator=(const OptionalCUDAGuard&) = delete; 90*da0073e9SAndroid Build Coastguard Worker 91*da0073e9SAndroid Build Coastguard Worker // See Note [Move construction for RAII guards is tricky] 92*da0073e9SAndroid Build Coastguard Worker OptionalCUDAGuard(OptionalCUDAGuard&& other) = delete; 93*da0073e9SAndroid Build Coastguard Worker 94*da0073e9SAndroid Build Coastguard Worker // See Note [Move assignment for RAII guards is tricky] 95*da0073e9SAndroid Build Coastguard Worker OptionalCUDAGuard& operator=(OptionalCUDAGuard&& other) = delete; 96*da0073e9SAndroid Build Coastguard Worker 97*da0073e9SAndroid Build Coastguard Worker /// Sets the CUDA device to the given device, initializing the guard if it 98*da0073e9SAndroid Build Coastguard Worker /// is not already initialized. Errors if the given device is not a CUDA 99*da0073e9SAndroid Build Coastguard Worker /// device. set_deviceOptionalCUDAGuard100*da0073e9SAndroid Build Coastguard Worker void set_device(Device device) { 101*da0073e9SAndroid Build Coastguard Worker guard_.set_device(device); 102*da0073e9SAndroid Build Coastguard Worker } 103*da0073e9SAndroid Build Coastguard Worker 104*da0073e9SAndroid Build Coastguard Worker /// Sets the CUDA device to the given device, initializing the guard if it is 105*da0073e9SAndroid Build Coastguard Worker /// not already initialized. Errors if the given device is not a CUDA device. 106*da0073e9SAndroid Build Coastguard Worker /// (This method is provided for uniformity with OptionalDeviceGuard). reset_deviceOptionalCUDAGuard107*da0073e9SAndroid Build Coastguard Worker void reset_device(Device device) { 108*da0073e9SAndroid Build Coastguard Worker guard_.reset_device(device); 109*da0073e9SAndroid Build Coastguard Worker } 110*da0073e9SAndroid Build Coastguard Worker 111*da0073e9SAndroid Build Coastguard Worker /// Sets the CUDA device to the given device index, initializing the guard if 112*da0073e9SAndroid Build Coastguard Worker /// it is not already initialized. set_indexOptionalCUDAGuard113*da0073e9SAndroid Build Coastguard Worker void set_index(DeviceIndex device_index) { 114*da0073e9SAndroid Build Coastguard Worker guard_.set_index(device_index); 115*da0073e9SAndroid Build Coastguard Worker } 116*da0073e9SAndroid Build Coastguard Worker 117*da0073e9SAndroid Build Coastguard Worker /// Returns the device that was set immediately prior to initialization of the 118*da0073e9SAndroid Build Coastguard Worker /// guard, or nullopt if the guard is uninitialized. original_deviceOptionalCUDAGuard119*da0073e9SAndroid Build Coastguard Worker std::optional<Device> original_device() const { 120*da0073e9SAndroid Build Coastguard Worker return guard_.original_device(); 121*da0073e9SAndroid Build Coastguard Worker } 122*da0073e9SAndroid Build Coastguard Worker 123*da0073e9SAndroid Build Coastguard Worker /// Returns the most recent device that was set using this device guard, 124*da0073e9SAndroid Build Coastguard Worker /// either from construction, or via set_device, if the guard is initialized, 125*da0073e9SAndroid Build Coastguard Worker /// or nullopt if the guard is uninitialized. current_deviceOptionalCUDAGuard126*da0073e9SAndroid Build Coastguard Worker std::optional<Device> current_device() const { 127*da0073e9SAndroid Build Coastguard Worker return guard_.current_device(); 128*da0073e9SAndroid Build Coastguard Worker } 129*da0073e9SAndroid Build Coastguard Worker 130*da0073e9SAndroid Build Coastguard Worker /// Restore the original CUDA device, resetting this guard to uninitialized 131*da0073e9SAndroid Build Coastguard Worker /// state. resetOptionalCUDAGuard132*da0073e9SAndroid Build Coastguard Worker void reset() { 133*da0073e9SAndroid Build Coastguard Worker guard_.reset(); 134*da0073e9SAndroid Build Coastguard Worker } 135*da0073e9SAndroid Build Coastguard Worker 136*da0073e9SAndroid Build Coastguard Worker private: 137*da0073e9SAndroid Build Coastguard Worker c10::impl::InlineOptionalDeviceGuard<impl::CUDAGuardImpl> guard_; 138*da0073e9SAndroid Build Coastguard Worker }; 139*da0073e9SAndroid Build Coastguard Worker 140*da0073e9SAndroid Build Coastguard Worker /// A variant of StreamGuard that is specialized for CUDA. See CUDAGuard 141*da0073e9SAndroid Build Coastguard Worker /// for when you can use this. 142*da0073e9SAndroid Build Coastguard Worker struct CUDAStreamGuard { 143*da0073e9SAndroid Build Coastguard Worker /// No default constructor, see Note [Omitted default constructor from RAII] 144*da0073e9SAndroid Build Coastguard Worker explicit CUDAStreamGuard() = delete; 145*da0073e9SAndroid Build Coastguard Worker 146*da0073e9SAndroid Build Coastguard Worker /// Set the current CUDA device to the device associated with the passed 147*da0073e9SAndroid Build Coastguard Worker /// stream, and set the current CUDA stream on that device to the passed 148*da0073e9SAndroid Build Coastguard Worker /// stream. Errors if the Stream is not a CUDA stream. CUDAStreamGuardCUDAStreamGuard149*da0073e9SAndroid Build Coastguard Worker explicit CUDAStreamGuard(Stream stream) : guard_(stream) {} 150*da0073e9SAndroid Build Coastguard Worker 151*da0073e9SAndroid Build Coastguard Worker /// Copy is disallowed 152*da0073e9SAndroid Build Coastguard Worker CUDAStreamGuard(const CUDAStreamGuard&) = delete; 153*da0073e9SAndroid Build Coastguard Worker CUDAStreamGuard& operator=(const CUDAStreamGuard&) = delete; 154*da0073e9SAndroid Build Coastguard Worker 155*da0073e9SAndroid Build Coastguard Worker /// Move is disallowed, as CUDAStreamGuard does not have an uninitialized 156*da0073e9SAndroid Build Coastguard Worker /// state, which is required for moves on types with nontrivial destructors. 157*da0073e9SAndroid Build Coastguard Worker CUDAStreamGuard(CUDAStreamGuard&& other) = delete; 158*da0073e9SAndroid Build Coastguard Worker CUDAStreamGuard& operator=(CUDAStreamGuard&& other) = delete; 159*da0073e9SAndroid Build Coastguard Worker 160*da0073e9SAndroid Build Coastguard Worker /// Resets the currently set stream to the original stream and 161*da0073e9SAndroid Build Coastguard Worker /// the currently set device to the original device. Then, 162*da0073e9SAndroid Build Coastguard Worker /// set the current device to the device associated with the passed stream, 163*da0073e9SAndroid Build Coastguard Worker /// and set the current stream on that device to the passed stream. 164*da0073e9SAndroid Build Coastguard Worker /// Errors if the stream passed is not a CUDA stream. 165*da0073e9SAndroid Build Coastguard Worker /// 166*da0073e9SAndroid Build Coastguard Worker /// NOTE: this implementation may skip some stream/device setting if 167*da0073e9SAndroid Build Coastguard Worker /// it can prove that it is unnecessary. 168*da0073e9SAndroid Build Coastguard Worker /// 169*da0073e9SAndroid Build Coastguard Worker /// WARNING: reset_stream does NOT preserve previously set streams on 170*da0073e9SAndroid Build Coastguard Worker /// different devices. If you need to set streams on multiple devices 171*da0073e9SAndroid Build Coastguard Worker /// on CUDA, use CUDAMultiStreamGuard instead. reset_streamCUDAStreamGuard172*da0073e9SAndroid Build Coastguard Worker void reset_stream(Stream stream) { 173*da0073e9SAndroid Build Coastguard Worker guard_.reset_stream(stream); 174*da0073e9SAndroid Build Coastguard Worker } 175*da0073e9SAndroid Build Coastguard Worker 176*da0073e9SAndroid Build Coastguard Worker /// Returns the CUDA stream that was set at the time the guard was 177*da0073e9SAndroid Build Coastguard Worker /// constructed. original_streamCUDAStreamGuard178*da0073e9SAndroid Build Coastguard Worker CUDAStream original_stream() const { 179*da0073e9SAndroid Build Coastguard Worker return CUDAStream(CUDAStream::UNCHECKED, guard_.original_stream()); 180*da0073e9SAndroid Build Coastguard Worker } 181*da0073e9SAndroid Build Coastguard Worker 182*da0073e9SAndroid Build Coastguard Worker /// Returns the most recent CUDA stream that was set using this device guard, 183*da0073e9SAndroid Build Coastguard Worker /// either from construction, or via set_stream. current_streamCUDAStreamGuard184*da0073e9SAndroid Build Coastguard Worker CUDAStream current_stream() const { 185*da0073e9SAndroid Build Coastguard Worker return CUDAStream(CUDAStream::UNCHECKED, guard_.current_stream()); 186*da0073e9SAndroid Build Coastguard Worker } 187*da0073e9SAndroid Build Coastguard Worker 188*da0073e9SAndroid Build Coastguard Worker /// Returns the most recent CUDA device that was set using this device guard, 189*da0073e9SAndroid Build Coastguard Worker /// either from construction, or via set_device/reset_device/set_index. current_deviceCUDAStreamGuard190*da0073e9SAndroid Build Coastguard Worker Device current_device() const { 191*da0073e9SAndroid Build Coastguard Worker return guard_.current_device(); 192*da0073e9SAndroid Build Coastguard Worker } 193*da0073e9SAndroid Build Coastguard Worker 194*da0073e9SAndroid Build Coastguard Worker /// Returns the CUDA device that was set at the most recent reset_stream(), 195*da0073e9SAndroid Build Coastguard Worker /// or otherwise the device at construction time. original_deviceCUDAStreamGuard196*da0073e9SAndroid Build Coastguard Worker Device original_device() const { 197*da0073e9SAndroid Build Coastguard Worker return guard_.original_device(); 198*da0073e9SAndroid Build Coastguard Worker } 199*da0073e9SAndroid Build Coastguard Worker 200*da0073e9SAndroid Build Coastguard Worker private: 201*da0073e9SAndroid Build Coastguard Worker c10::impl::InlineStreamGuard<impl::CUDAGuardImpl> guard_; 202*da0073e9SAndroid Build Coastguard Worker }; 203*da0073e9SAndroid Build Coastguard Worker 204*da0073e9SAndroid Build Coastguard Worker /// A variant of OptionalStreamGuard that is specialized for CUDA. See 205*da0073e9SAndroid Build Coastguard Worker /// CUDAGuard for when you can use this. 206*da0073e9SAndroid Build Coastguard Worker struct OptionalCUDAStreamGuard { 207*da0073e9SAndroid Build Coastguard Worker /// Create an uninitialized guard. OptionalCUDAStreamGuardOptionalCUDAStreamGuard208*da0073e9SAndroid Build Coastguard Worker explicit OptionalCUDAStreamGuard() : guard_() {} 209*da0073e9SAndroid Build Coastguard Worker 210*da0073e9SAndroid Build Coastguard Worker /// Set the current CUDA device to the device associated with the passed 211*da0073e9SAndroid Build Coastguard Worker /// stream, and set the current CUDA stream on that device to the passed 212*da0073e9SAndroid Build Coastguard Worker /// stream. Errors if the Stream is not a CUDA stream. OptionalCUDAStreamGuardOptionalCUDAStreamGuard213*da0073e9SAndroid Build Coastguard Worker explicit OptionalCUDAStreamGuard(Stream stream) : guard_(stream) {} 214*da0073e9SAndroid Build Coastguard Worker 215*da0073e9SAndroid Build Coastguard Worker /// Set the current device to the device associated with the passed stream, 216*da0073e9SAndroid Build Coastguard Worker /// and set the current stream on that device to the passed stream, 217*da0073e9SAndroid Build Coastguard Worker /// if the passed stream is not nullopt. OptionalCUDAStreamGuardOptionalCUDAStreamGuard218*da0073e9SAndroid Build Coastguard Worker explicit OptionalCUDAStreamGuard(std::optional<Stream> stream_opt) 219*da0073e9SAndroid Build Coastguard Worker : guard_(stream_opt) {} 220*da0073e9SAndroid Build Coastguard Worker 221*da0073e9SAndroid Build Coastguard Worker /// Copy is disallowed 222*da0073e9SAndroid Build Coastguard Worker OptionalCUDAStreamGuard(const OptionalCUDAStreamGuard&) = delete; 223*da0073e9SAndroid Build Coastguard Worker OptionalCUDAStreamGuard& operator=(const OptionalCUDAStreamGuard&) = delete; 224*da0073e9SAndroid Build Coastguard Worker 225*da0073e9SAndroid Build Coastguard Worker // See Note [Move construction for RAII guards is tricky] 226*da0073e9SAndroid Build Coastguard Worker OptionalCUDAStreamGuard(OptionalCUDAStreamGuard&& other) = delete; 227*da0073e9SAndroid Build Coastguard Worker 228*da0073e9SAndroid Build Coastguard Worker // See Note [Move assignment for RAII guards is tricky] 229*da0073e9SAndroid Build Coastguard Worker OptionalCUDAStreamGuard& operator=(OptionalCUDAStreamGuard&& other) = delete; 230*da0073e9SAndroid Build Coastguard Worker 231*da0073e9SAndroid Build Coastguard Worker /// Resets the currently set CUDA stream to the original stream and 232*da0073e9SAndroid Build Coastguard Worker /// the currently set device to the original device. Then, 233*da0073e9SAndroid Build Coastguard Worker /// set the current device to the device associated with the passed stream, 234*da0073e9SAndroid Build Coastguard Worker /// and set the current stream on that device to the passed stream. 235*da0073e9SAndroid Build Coastguard Worker /// Initializes the guard if it was not previously initialized. reset_streamOptionalCUDAStreamGuard236*da0073e9SAndroid Build Coastguard Worker void reset_stream(Stream stream) { 237*da0073e9SAndroid Build Coastguard Worker guard_.reset_stream(stream); 238*da0073e9SAndroid Build Coastguard Worker } 239*da0073e9SAndroid Build Coastguard Worker 240*da0073e9SAndroid Build Coastguard Worker /// Returns the CUDA stream that was set at the time the guard was most 241*da0073e9SAndroid Build Coastguard Worker /// recently initialized, or nullopt if the guard is uninitialized. original_streamOptionalCUDAStreamGuard242*da0073e9SAndroid Build Coastguard Worker std::optional<CUDAStream> original_stream() const { 243*da0073e9SAndroid Build Coastguard Worker auto r = guard_.original_stream(); 244*da0073e9SAndroid Build Coastguard Worker if (r.has_value()) { 245*da0073e9SAndroid Build Coastguard Worker return std::make_optional(CUDAStream(CUDAStream::UNCHECKED, r.value())); 246*da0073e9SAndroid Build Coastguard Worker } else { 247*da0073e9SAndroid Build Coastguard Worker return std::nullopt; 248*da0073e9SAndroid Build Coastguard Worker } 249*da0073e9SAndroid Build Coastguard Worker } 250*da0073e9SAndroid Build Coastguard Worker 251*da0073e9SAndroid Build Coastguard Worker /// Returns the most recent CUDA stream that was set using this stream guard, 252*da0073e9SAndroid Build Coastguard Worker /// either from construction, or via reset_stream, if the guard is 253*da0073e9SAndroid Build Coastguard Worker /// initialized, or nullopt if the guard is uninitialized. current_streamOptionalCUDAStreamGuard254*da0073e9SAndroid Build Coastguard Worker std::optional<CUDAStream> current_stream() const { 255*da0073e9SAndroid Build Coastguard Worker auto r = guard_.current_stream(); 256*da0073e9SAndroid Build Coastguard Worker if (r.has_value()) { 257*da0073e9SAndroid Build Coastguard Worker return std::make_optional(CUDAStream(CUDAStream::UNCHECKED, r.value())); 258*da0073e9SAndroid Build Coastguard Worker } else { 259*da0073e9SAndroid Build Coastguard Worker return std::nullopt; 260*da0073e9SAndroid Build Coastguard Worker } 261*da0073e9SAndroid Build Coastguard Worker } 262*da0073e9SAndroid Build Coastguard Worker 263*da0073e9SAndroid Build Coastguard Worker /// Restore the original CUDA device and stream, resetting this guard to 264*da0073e9SAndroid Build Coastguard Worker /// uninitialized state. resetOptionalCUDAStreamGuard265*da0073e9SAndroid Build Coastguard Worker void reset() { 266*da0073e9SAndroid Build Coastguard Worker guard_.reset(); 267*da0073e9SAndroid Build Coastguard Worker } 268*da0073e9SAndroid Build Coastguard Worker 269*da0073e9SAndroid Build Coastguard Worker private: 270*da0073e9SAndroid Build Coastguard Worker c10::impl::InlineOptionalStreamGuard<impl::CUDAGuardImpl> guard_; 271*da0073e9SAndroid Build Coastguard Worker }; 272*da0073e9SAndroid Build Coastguard Worker 273*da0073e9SAndroid Build Coastguard Worker /// A variant of MultiStreamGuard that is specialized for CUDA. 274*da0073e9SAndroid Build Coastguard Worker struct CUDAMultiStreamGuard { CUDAMultiStreamGuardCUDAMultiStreamGuard275*da0073e9SAndroid Build Coastguard Worker explicit CUDAMultiStreamGuard(ArrayRef<CUDAStream> streams) 276*da0073e9SAndroid Build Coastguard Worker : guard_(unwrapStreams(streams)) {} 277*da0073e9SAndroid Build Coastguard Worker 278*da0073e9SAndroid Build Coastguard Worker /// Copy is disallowed 279*da0073e9SAndroid Build Coastguard Worker CUDAMultiStreamGuard(const CUDAMultiStreamGuard&) = delete; 280*da0073e9SAndroid Build Coastguard Worker CUDAMultiStreamGuard& operator=(const CUDAMultiStreamGuard&) = delete; 281*da0073e9SAndroid Build Coastguard Worker 282*da0073e9SAndroid Build Coastguard Worker // See Note [Move construction for RAII guards is tricky] 283*da0073e9SAndroid Build Coastguard Worker CUDAMultiStreamGuard(CUDAMultiStreamGuard&& other) = delete; 284*da0073e9SAndroid Build Coastguard Worker 285*da0073e9SAndroid Build Coastguard Worker // See Note [Move assignment for RAII guards is tricky] 286*da0073e9SAndroid Build Coastguard Worker CUDAMultiStreamGuard& operator=(CUDAMultiStreamGuard&& other) = delete; 287*da0073e9SAndroid Build Coastguard Worker 288*da0073e9SAndroid Build Coastguard Worker private: 289*da0073e9SAndroid Build Coastguard Worker c10::impl::InlineMultiStreamGuard<impl::CUDAGuardImpl> guard_; 290*da0073e9SAndroid Build Coastguard Worker unwrapStreamsCUDAMultiStreamGuard291*da0073e9SAndroid Build Coastguard Worker static std::vector<Stream> unwrapStreams(ArrayRef<CUDAStream> cudaStreams) { 292*da0073e9SAndroid Build Coastguard Worker std::vector<Stream> streams; 293*da0073e9SAndroid Build Coastguard Worker streams.reserve(cudaStreams.size()); 294*da0073e9SAndroid Build Coastguard Worker for (const CUDAStream& cudaStream : cudaStreams) { 295*da0073e9SAndroid Build Coastguard Worker streams.push_back(cudaStream); 296*da0073e9SAndroid Build Coastguard Worker } 297*da0073e9SAndroid Build Coastguard Worker return streams; 298*da0073e9SAndroid Build Coastguard Worker } 299*da0073e9SAndroid Build Coastguard Worker }; 300*da0073e9SAndroid Build Coastguard Worker 301*da0073e9SAndroid Build Coastguard Worker } // namespace c10::cuda 302