1 #pragma once 2 3 #include <ATen/Tensor.h> 4 #include <ATen/cudnn/cudnn-wrapper.h> 5 6 namespace at::native { 7 8 TORCH_CUDA_CPP_API cudnnDataType_t 9 getCudnnDataTypeFromScalarType(const at::ScalarType dtype); 10 cudnnDataType_t getCudnnDataType(const at::Tensor& tensor); 11 12 int64_t cudnn_version(); 13 14 } // namespace at::native 15