1 #pragma once 2 3 #include <ATen/detail/CUDAHooksInterface.h> 4 5 #include <ATen/Generator.h> 6 #include <optional> 7 8 // TODO: No need to have this whole header, we can just put it all in 9 // the cpp file 10 11 namespace at::cuda::detail { 12 13 // Set the callback to initialize Magma, which is set by 14 // torch_cuda_cu. This indirection is required so magma_init is called 15 // in the same library where Magma will be used. 16 TORCH_CUDA_CPP_API void set_magma_init_fn(void (*magma_init_fn)()); 17 18 19 // The real implementation of CUDAHooksInterface 20 struct CUDAHooks : public at::CUDAHooksInterface { CUDAHooksCUDAHooks21 CUDAHooks(at::CUDAHooksArgs) {} 22 void initCUDA() const override; 23 Device getDeviceFromPtr(void* data) const override; 24 bool isPinnedPtr(const void* data) const override; 25 const Generator& getDefaultCUDAGenerator(DeviceIndex device_index = -1) const override; 26 bool hasCUDA() const override; 27 bool hasMAGMA() const override; 28 bool hasCuDNN() const override; 29 bool hasCuSOLVER() const override; 30 bool hasCuBLASLt() const override; 31 bool hasROCM() const override; 32 const at::cuda::NVRTC& nvrtc() const override; 33 DeviceIndex current_device() const override; 34 bool hasPrimaryContext(DeviceIndex device_index) const override; 35 Allocator* getCUDADeviceAllocator() const override; 36 Allocator* getPinnedMemoryAllocator() const override; 37 bool compiledWithCuDNN() const override; 38 bool compiledWithMIOpen() const override; 39 bool supportsDilatedConvolutionWithCuDNN() const override; 40 bool supportsDepthwiseConvolutionWithCuDNN() const override; 41 bool supportsBFloat16ConvolutionWithCuDNNv8() const override; 42 bool hasCUDART() const override; 43 long versionCUDART() const override; 44 long versionCuDNN() const override; 45 std::string showConfig() const override; 46 double batchnormMinEpsilonCuDNN() const override; 47 int64_t cuFFTGetPlanCacheMaxSize(DeviceIndex device_index) const override; 48 void cuFFTSetPlanCacheMaxSize(DeviceIndex device_index, int64_t max_size) const override; 49 int64_t cuFFTGetPlanCacheSize(DeviceIndex device_index) const override; 50 void cuFFTClearPlanCache(DeviceIndex device_index) const override; 51 int getNumGPUs() const override; 52 #ifdef USE_ROCM 53 bool isGPUArch(DeviceIndex device_index, const std::vector<std::string>& archs) const override; 54 #endif 55 void deviceSynchronize(DeviceIndex device_index) const override; 56 }; 57 58 } // at::cuda::detail 59