#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/DeviceThreadHandles.h>

#if defined(CUDART_VERSION) || defined(USE_ROCM)

namespace at::cuda {
namespace {

void createCusolverDnHandle(cusolverDnHandle_t* handle) {
  TORCH_CUSOLVER_CHECK(cusolverDnCreate(handle));
}

void destroyCusolverDnHandle(cusolverDnHandle_t handle) {
// this is because of something dumb in the ordering of
// destruction. Sometimes, during atexit, the CUDA context (or something)
// would already be destroyed by the time this gets destroyed. It
// happens in the fbcode setting. @colesbury and @soumith decided to not
// destroy the handle as a workaround.
//   - Comments of @soumith copied from the cuDNN handle pool implementation
#ifdef NO_CUDNN_DESTROY_HANDLE
  (void)handle; // Suppress unused variable warning
#else
  cusolverDnDestroy(handle);
#endif
}

// Pool of cusolverDn handles, keyed by device; each thread borrows handles
// through its own PoolWindow.
using CuSolverDnPoolType = DeviceThreadHandlePool<
    cusolverDnHandle_t,
    createCusolverDnHandle,
    destroyCusolverDnHandle>;

} // namespace

cusolverDnHandle_t getCurrentCUDASolverDnHandle() {
  c10::DeviceIndex device = 0;
  AT_CUDA_CHECK(c10::cuda::GetDevice(&device));

  // Thread-local PoolWindows are lazily initialized
  // to avoid initialization issues that caused hangs on Windows.
  // See: https://github.com/pytorch/pytorch/pull/22405
  // This thread-local unique_ptr will be destroyed when the thread terminates,
  // releasing its reserved handles back to the pool.
  static auto pool = std::make_shared<CuSolverDnPoolType>();
  thread_local std::unique_ptr<CuSolverDnPoolType::PoolWindow> myPoolWindow(
      pool->newPoolWindow());

  auto handle = myPoolWindow->reserve(device);
  auto stream = c10::cuda::getCurrentCUDAStream();
  // Bind the handle to the current stream so cusolverDn calls issued on it
  // are ordered with the rest of the work on that stream.
  TORCH_CUSOLVER_CHECK(cusolverDnSetStream(handle, stream));
  return handle;
}

} // namespace at::cuda

#endif // defined(CUDART_VERSION) || defined(USE_ROCM)
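
// ---------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of this file): a caller grabs the
// pooled handle and drives a cusolverDn routine with it. The handle returned
// by getCurrentCUDASolverDnHandle() is already bound to the current CUDA
// stream, so no extra cusolverDnSetStream() call is needed.
// choleskyWorkspaceQuery is a hypothetical helper name;
// cusolverDnDpotrf_bufferSize is the real cuSOLVER workspace-query API.
//
//   #include <ATen/cuda/CUDAContext.h>
//   #include <ATen/cuda/Exceptions.h>
//
//   int choleskyWorkspaceQuery(double* A, int n, int lda) {
//     cusolverDnHandle_t handle = at::cuda::getCurrentCUDASolverDnHandle();
//     int lwork = 0;
//     TORCH_CUSOLVER_CHECK(cusolverDnDpotrf_bufferSize(
//         handle, CUBLAS_FILL_MODE_UPPER, n, A, lda, &lwork));
//     return lwork; // doubles of device workspace that dpotrf would need
//   }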