xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/linalg/CudssHandlePool.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/cuda/CUDAContext.h>
2 #include <ATen/cuda/detail/DeviceThreadHandles.h>
3 
4 #if defined(USE_CUDSS)
5 
6 namespace at::cuda {
7 namespace {
8 
createCudssHandle(cudssHandle_t * handle)9 void createCudssHandle(cudssHandle_t *handle) {
10   TORCH_CUDSS_CHECK(cudssCreate(handle));
11 }
12 
destroyCudssHandle(cudssHandle_t handle)13 void destroyCudssHandle(cudssHandle_t handle) {
14 // this is because of something dumb in the ordering of
15 // destruction. Sometimes atexit, the cuda context (or something)
16 // would already be destroyed by the time this gets destroyed. It
17 // happens in fbcode setting. @colesbury and @soumith decided to not destroy
18 // the handle as a workaround.
19 //   - Comments of @soumith copied from cuDNN handle pool implementation
20 #ifdef NO_CUDNN_DESTROY_HANDLE
21   (void)handle; // Suppress unused variable warning
22 #else
23     cudssDestroy(handle);
24 #endif
25 }
26 
27 using CudssPoolType = DeviceThreadHandlePool<cudssHandle_t, createCudssHandle, destroyCudssHandle>;
28 
29 } // namespace
30 
getCurrentCudssHandle()31 cudssHandle_t getCurrentCudssHandle() {
32   c10::DeviceIndex device = 0;
33   AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
34 
35   // Thread local PoolWindows are lazily-initialized
36   // to avoid initialization issues that caused hangs on Windows.
37   // See: https://github.com/pytorch/pytorch/pull/22405
38   // This thread local unique_ptrs will be destroyed when the thread terminates,
39   // releasing its reserved handles back to the pool.
40   static auto pool = std::make_shared<CudssPoolType>();
41   thread_local std::unique_ptr<CudssPoolType::PoolWindow> myPoolWindow(
42       pool->newPoolWindow());
43 
44   auto handle = myPoolWindow->reserve(device);
45   auto stream = c10::cuda::getCurrentCUDAStream();
46   TORCH_CUDSS_CHECK(cudssSetStream(handle, stream));
47   return handle;
48 }
49 
50 } // namespace at::cuda
51 
52 #endif
53