// aten/src/ATen/cuda/CuSparseHandlePool.cpp
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/DeviceThreadHandles.h>

namespace at::cuda {
namespace {

void createCusparseHandle(cusparseHandle_t *handle) {
  TORCH_CUDASPARSE_CHECK(cusparseCreate(handle));
}

void destroyCusparseHandle(cusparseHandle_t handle) {
// Destruction order at process exit is unreliable: by the time this runs
// via atexit, the CUDA context (or some related driver state) may already
// have been torn down, which was observed in the fbcode environment.
// @colesbury and @soumith decided not to destroy the handle as a workaround.
//   - Comment from @soumith, copied from the cuDNN handle pool implementation
#ifndef NO_CUDNN_DESTROY_HANDLE
  cusparseDestroy(handle);
#endif
}

using CuSparsePoolType = DeviceThreadHandlePool<cusparseHandle_t, createCusparseHandle, destroyCusparseHandle>;

} // namespace

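// For orientation, a minimal sketch of the pattern DeviceThreadHandlePool
// implements, under simplified assumptions (plain std::map containers, no
// handle-count limits): a process-wide pool keeps free handles per device,
// and each thread owns a "window" that reserves at most one handle per
// device and returns its handles to the pool on thread exit. SimplePool and
// Window below are illustrative stand-ins, not the real API, and the block
// is excluded from the build.
#if 0
#include <map>
#include <mutex>
#include <vector>

template <typename Handle, void Create(Handle*), void Destroy(Handle)>
struct SimplePool {
  // Destroy is kept for parity with the real template parameters but is
  // deliberately unused here; as in the file above, handles are not
  // destroyed on thread exit, only returned for reuse.
  std::mutex mutex;
  std::map<int, std::vector<Handle>> free_handles; // per-device free lists

  struct Window {
    SimplePool& pool;
    std::map<int, Handle> reserved; // this thread's handle for each device

    explicit Window(SimplePool& p) : pool(p) {}

    Handle reserve(int device) {
      auto it = reserved.find(device);
      if (it != reserved.end()) {
        return it->second; // reuse the handle this thread already holds
      }
      Handle h;
      {
        std::lock_guard<std::mutex> lock(pool.mutex);
        auto& free_list = pool.free_handles[device];
        if (!free_list.empty()) {
          h = free_list.back(); // recycle a handle a finished thread released
          free_list.pop_back();
          reserved.emplace(device, h);
          return h;
        }
      }
      Create(&h); // pool empty for this device: create outside the lock
      reserved.emplace(device, h);
      return h;
    }

    // Thread exit: hand reserved handles back instead of destroying them,
    // so other threads can recycle them later.
    ~Window() {
      std::lock_guard<std::mutex> lock(pool.mutex);
      for (auto& kv : reserved) {
        pool.free_handles[kv.first].push_back(kv.second);
      }
    }
  };
};
#endif
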
cusparseHandle_t getCurrentCUDASparseHandle() {
  c10::DeviceIndex device = 0;
  AT_CUDA_CHECK(c10::cuda::GetDevice(&device));

  // Thread-local PoolWindows are lazily initialized to avoid the
  // initialization issues that caused hangs on Windows.
  // See: https://github.com/pytorch/pytorch/pull/22405
  // Each thread-local unique_ptr is destroyed when its thread terminates,
  // releasing that thread's reserved handles back to the pool.
  static auto pool = std::make_shared<CuSparsePoolType>();
  thread_local std::unique_ptr<CuSparsePoolType::PoolWindow> myPoolWindow(
      pool->newPoolWindow());

  auto handle = myPoolWindow->reserve(device);
  TORCH_CUDASPARSE_CHECK(cusparseSetStream(handle, c10::cuda::getCurrentCUDAStream()));
  return handle;
}

} // namespace at::cuda
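
// A hypothetical caller sketch (also excluded from the build): how code
// using getCurrentCUDASparseHandle() typically interacts with it. The
// caller never creates or destroys the handle; it just fetches the thread's
// handle, which is already bound to the current CUDA stream. exampleCaller
// is an illustrative name, not a real ATen function.
#if 0
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>

void exampleCaller() {
  // Per-thread, per-device handle; safe to call from any thread once the
  // current device has been set (e.g., via c10::cuda::CUDAGuard).
  cusparseHandle_t handle = at::cuda::getCurrentCUDASparseHandle();

  // The handle is already associated with this thread's current stream, so
  // cuSPARSE calls made with it are ordered with other work on that stream.
  cudaStream_t stream = nullptr;
  TORCH_CUDASPARSE_CHECK(cusparseGetStream(handle, &stream));
  TORCH_INTERNAL_ASSERT(stream == c10::cuda::getCurrentCUDAStream().stream());
}
#endif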