1 #pragma once 2 3 #include <ATen/Config.h> 4 5 #include <string> 6 #include <stdexcept> 7 #include <sstream> 8 #include <cufft.h> 9 #include <cufftXt.h> 10 11 namespace at { namespace native { 12 13 // This means that max dim is 3 + 2 = 5 with batch dimension and possible 14 // complex dimension 15 constexpr int max_rank = 3; 16 _cudaGetErrorEnum(cufftResult error)17static inline std::string _cudaGetErrorEnum(cufftResult error) 18 { 19 switch (error) 20 { 21 case CUFFT_SUCCESS: 22 return "CUFFT_SUCCESS"; 23 case CUFFT_INVALID_PLAN: 24 return "CUFFT_INVALID_PLAN"; 25 case CUFFT_ALLOC_FAILED: 26 return "CUFFT_ALLOC_FAILED"; 27 case CUFFT_INVALID_TYPE: 28 return "CUFFT_INVALID_TYPE"; 29 case CUFFT_INVALID_VALUE: 30 return "CUFFT_INVALID_VALUE"; 31 case CUFFT_INTERNAL_ERROR: 32 return "CUFFT_INTERNAL_ERROR"; 33 case CUFFT_EXEC_FAILED: 34 return "CUFFT_EXEC_FAILED"; 35 case CUFFT_SETUP_FAILED: 36 return "CUFFT_SETUP_FAILED"; 37 case CUFFT_INVALID_SIZE: 38 return "CUFFT_INVALID_SIZE"; 39 case CUFFT_UNALIGNED_DATA: 40 return "CUFFT_UNALIGNED_DATA"; 41 case CUFFT_INCOMPLETE_PARAMETER_LIST: 42 return "CUFFT_INCOMPLETE_PARAMETER_LIST"; 43 case CUFFT_INVALID_DEVICE: 44 return "CUFFT_INVALID_DEVICE"; 45 case CUFFT_PARSE_ERROR: 46 return "CUFFT_PARSE_ERROR"; 47 case CUFFT_NO_WORKSPACE: 48 return "CUFFT_NO_WORKSPACE"; 49 case CUFFT_NOT_IMPLEMENTED: 50 return "CUFFT_NOT_IMPLEMENTED"; 51 #if !defined(USE_ROCM) 52 case CUFFT_LICENSE_ERROR: 53 return "CUFFT_LICENSE_ERROR"; 54 #endif 55 case CUFFT_NOT_SUPPORTED: 56 return "CUFFT_NOT_SUPPORTED"; 57 default: 58 std::ostringstream ss; 59 ss << "unknown error " << error; 60 return ss.str(); 61 } 62 } 63 CUFFT_CHECK(cufftResult error)64static inline void CUFFT_CHECK(cufftResult error) 65 { 66 if (error != CUFFT_SUCCESS) { 67 std::ostringstream ss; 68 ss << "cuFFT error: " << _cudaGetErrorEnum(error); 69 AT_ERROR(ss.str()); 70 } 71 } 72 73 }} // at::native 74