#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/DeviceThreadHandles.h>

#if defined(USE_CUDSS)

namespace at::cuda {
namespace {

// Called by the pool whenever a fresh handle is needed for a device.
void createCudssHandle(cudssHandle_t* handle) {
  TORCH_CUDSS_CHECK(cudssCreate(handle));
}

// Called by the pool when a handle is retired.
void destroyCudssHandle(cudssHandle_t handle) {
// this is because of something dumb in the ordering of
// destruction. Sometimes atexit, the cuda context (or something)
// would already be destroyed by the time this gets destroyed. It
// happens in fbcode setting. @colesbury and @soumith decided to not destroy
// the handle as a workaround.
//   - Comments of @soumith copied from cuDNN handle pool implementation
#ifdef NO_CUDNN_DESTROY_HANDLE
  (void)handle; // Suppress unused variable warning
#else
  cudssDestroy(handle);
#endif
}

// A process-wide pool of cuDSS handles, keyed by device, from which each
// thread reserves handles through its own PoolWindow.
using CudssPoolType =
    DeviceThreadHandlePool<cudssHandle_t, createCudssHandle, destroyCudssHandle>;

} // namespace

// Returns a cuDSS handle reserved for the current thread and device, with
// the current CUDA stream already attached.
cudssHandle_t getCurrentCudssHandle() {
  c10::DeviceIndex device = 0;
  AT_CUDA_CHECK(c10::cuda::GetDevice(&device));

  // Thread-local PoolWindows are lazily initialized
  // to avoid initialization issues that caused hangs on Windows.
  // See: https://github.com/pytorch/pytorch/pull/22405
  // This thread-local unique_ptr will be destroyed when the thread terminates,
  // releasing its reserved handles back to the pool.
  static auto pool = std::make_shared<CudssPoolType>();
  thread_local std::unique_ptr<CudssPoolType::PoolWindow> myPoolWindow(
      pool->newPoolWindow());

  auto handle = myPoolWindow->reserve(device);
  auto stream = c10::cuda::getCurrentCUDAStream();
  TORCH_CUDSS_CHECK(cudssSetStream(handle, stream));
  return handle;
}

} // namespace at::cuda

#endif
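
// Usage sketch (illustrative comment only, not compiled; the function name
// below is hypothetical). A caller that needs to issue cuDSS library calls
// should take the pooled handle instead of creating its own: the handle comes
// back already bound to the current CUDA stream, so cudss* work is ordered
// with everything else PyTorch has queued on that stream, and each thread
// reserves its own handle per device, so no cross-thread synchronization on
// the handle is required.
//
//   void sparseSolveImpl(/* cuDSS matrix/vector arguments */) {
//     cudssHandle_t handle = at::cuda::getCurrentCudssHandle();
//     // ... cudss* calls using `handle` run on the current stream; the
//     // handle returns to this thread's pool window when the thread exits.
//   }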