xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/detail/CUDAHooks.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/detail/CUDAHooksInterface.h>
4 
5 #include <ATen/Generator.h>
6 #include <optional>
7 
// TODO: This header is not strictly necessary; everything in it could be moved
// into the .cpp file.
10 
11 namespace at::cuda::detail {
12 
13 // Set the callback to initialize Magma, which is set by
14 // torch_cuda_cu. This indirection is required so magma_init is called
15 // in the same library where Magma will be used.
16 TORCH_CUDA_CPP_API void set_magma_init_fn(void (*magma_init_fn)());
17 
18 
19 // The real implementation of CUDAHooksInterface
20 struct CUDAHooks : public at::CUDAHooksInterface {
CUDAHooksCUDAHooks21   CUDAHooks(at::CUDAHooksArgs) {}
22   void initCUDA() const override;
23   Device getDeviceFromPtr(void* data) const override;
24   bool isPinnedPtr(const void* data) const override;
25   const Generator& getDefaultCUDAGenerator(DeviceIndex device_index = -1) const override;
26   bool hasCUDA() const override;
27   bool hasMAGMA() const override;
28   bool hasCuDNN() const override;
29   bool hasCuSOLVER() const override;
30   bool hasCuBLASLt() const override;
31   bool hasROCM() const override;
32   const at::cuda::NVRTC& nvrtc() const override;
33   DeviceIndex current_device() const override;
34   bool hasPrimaryContext(DeviceIndex device_index) const override;
35   Allocator* getCUDADeviceAllocator() const override;
36   Allocator* getPinnedMemoryAllocator() const override;
37   bool compiledWithCuDNN() const override;
38   bool compiledWithMIOpen() const override;
39   bool supportsDilatedConvolutionWithCuDNN() const override;
40   bool supportsDepthwiseConvolutionWithCuDNN() const override;
41   bool supportsBFloat16ConvolutionWithCuDNNv8() const override;
42   bool hasCUDART() const override;
43   long versionCUDART() const override;
44   long versionCuDNN() const override;
45   std::string showConfig() const override;
46   double batchnormMinEpsilonCuDNN() const override;
47   int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex device_index) const override;
48   void cuFFTSetPlanCacheMaxSize(DeviceIndex device_index, int64_t max_size) const override;
49   int64_t cuFFTGetPlanCacheSize(DeviceIndex device_index) const override;
50   void cuFFTClearPlanCache(DeviceIndex device_index) const override;
51   int getNumGPUs() const override;
52 #ifdef USE_ROCM
53   bool isGPUArch(DeviceIndex device_index, const std::vector<std::string>& archs) const override;
54 #endif
55   void deviceSynchronize(DeviceIndex device_index) const override;
56 };
57 
58 } // at::cuda::detail
59