1 #include <ATen/cudnn/Types.h> 2 3 #include <ATen/ATen.h> 4 5 namespace at::native { 6 getCudnnDataTypeFromScalarType(const at::ScalarType dtype)7cudnnDataType_t getCudnnDataTypeFromScalarType(const at::ScalarType dtype) { 8 if (dtype == c10::kQInt8) { 9 return CUDNN_DATA_INT8; 10 } else if (dtype == at::kFloat) { 11 return CUDNN_DATA_FLOAT; 12 } else if (dtype == at::kDouble) { 13 return CUDNN_DATA_DOUBLE; 14 } else if (dtype == at::kHalf) { 15 return CUDNN_DATA_HALF; 16 } else if (dtype == at::kBFloat16) { 17 return CUDNN_DATA_BFLOAT16; 18 } else if (dtype == at::kInt) { 19 return CUDNN_DATA_INT32; 20 } else if (dtype == at::kByte) { 21 return CUDNN_DATA_UINT8; 22 } else if (dtype == at::kChar) { 23 return CUDNN_DATA_INT8; 24 } 25 std::string msg("getCudnnDataTypeFromScalarType() not supported for "); 26 msg += toString(dtype); 27 throw std::runtime_error(msg); 28 } 29 getCudnnDataType(const at::Tensor & tensor)30cudnnDataType_t getCudnnDataType(const at::Tensor& tensor) { 31 return getCudnnDataTypeFromScalarType(tensor.scalar_type()); 32 } 33 cudnn_version()34int64_t cudnn_version() { 35 return CUDNN_VERSION; 36 } 37 38 } // namespace at::native 39