#pragma once #include #include #include // CUDA Graphs utils used by c10 and aten. // aten/cuda/CUDAGraphsUtils.cuh adds utils used by aten only. namespace c10::cuda { using CaptureId_t = unsigned long long; // first is set if the instance is created by CUDAGraph::capture_begin. // second is set if the instance is created by at::cuda::graph_pool_handle. using MempoolId_t = std::pair; // RAII guard for "cudaStreamCaptureMode", a thread-local value // that controls the error-checking strictness of a capture. struct C10_CUDA_API CUDAStreamCaptureModeGuard { CUDAStreamCaptureModeGuard(cudaStreamCaptureMode desired) : strictness_(desired) { C10_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&strictness_)); } ~CUDAStreamCaptureModeGuard() { C10_CUDA_CHECK_WARN(cudaThreadExchangeStreamCaptureMode(&strictness_)); } private: cudaStreamCaptureMode strictness_; }; // Protects against enum cudaStreamCaptureStatus implementation changes. // Some compilers seem not to like static_assert without the messages. static_assert( int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone) == 0, "unexpected int(cudaStreamCaptureStatusNone) value"); static_assert( int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive) == 1, "unexpected int(cudaStreamCaptureStatusActive) value"); static_assert( int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated) == 2, "unexpected int(cudaStreamCaptureStatusInvalidated) value"); enum class CaptureStatus : int { None = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusNone), Active = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusActive), Invalidated = int(cudaStreamCaptureStatus::cudaStreamCaptureStatusInvalidated) }; inline std::ostream& operator<<(std::ostream& os, CaptureStatus status) { switch (status) { case CaptureStatus::None: os << "cudaStreamCaptureStatusNone"; break; case CaptureStatus::Active: os << "cudaStreamCaptureStatusActive"; break; case CaptureStatus::Invalidated: os << "cudaStreamCaptureStatusInvalidated"; break; default: TORCH_INTERNAL_ASSERT( false, "Unknown CUDA graph CaptureStatus", int(status)); } return os; } // Use this version where you're sure a CUDA context exists already. inline CaptureStatus currentStreamCaptureStatusMayInitCtx() { cudaStreamCaptureStatus is_capturing{cudaStreamCaptureStatusNone}; C10_CUDA_CHECK( cudaStreamIsCapturing(c10::cuda::getCurrentCUDAStream(), &is_capturing)); return CaptureStatus(is_capturing); } } // namespace c10::cuda