xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/CuFFTUtils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/Config.h>
4 
5 #include <string>
6 #include <stdexcept>
7 #include <sstream>
8 #include <cufft.h>
9 #include <cufftXt.h>
10 
11 namespace at { namespace native {
12 
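// Maximum rank of an FFT, i.e. the number of dimensions being transformed.
// Counting the batch dimension and a possible complex dimension as well, a
// tensor can therefore have at most max_rank + 2 = 5 dimensions.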
15 constexpr int max_rank = 3;
16 
_cudaGetErrorEnum(cufftResult error)17 static inline std::string _cudaGetErrorEnum(cufftResult error)
18 {
19   switch (error)
20   {
21     case CUFFT_SUCCESS:
22       return "CUFFT_SUCCESS";
23     case CUFFT_INVALID_PLAN:
24       return "CUFFT_INVALID_PLAN";
25     case CUFFT_ALLOC_FAILED:
26       return "CUFFT_ALLOC_FAILED";
27     case CUFFT_INVALID_TYPE:
28       return "CUFFT_INVALID_TYPE";
29     case CUFFT_INVALID_VALUE:
30       return "CUFFT_INVALID_VALUE";
31     case CUFFT_INTERNAL_ERROR:
32       return "CUFFT_INTERNAL_ERROR";
33     case CUFFT_EXEC_FAILED:
34       return "CUFFT_EXEC_FAILED";
35     case CUFFT_SETUP_FAILED:
36       return "CUFFT_SETUP_FAILED";
37     case CUFFT_INVALID_SIZE:
38       return "CUFFT_INVALID_SIZE";
39     case CUFFT_UNALIGNED_DATA:
40       return "CUFFT_UNALIGNED_DATA";
41     case CUFFT_INCOMPLETE_PARAMETER_LIST:
42       return "CUFFT_INCOMPLETE_PARAMETER_LIST";
43     case CUFFT_INVALID_DEVICE:
44       return "CUFFT_INVALID_DEVICE";
45     case CUFFT_PARSE_ERROR:
46       return "CUFFT_PARSE_ERROR";
47     case CUFFT_NO_WORKSPACE:
48       return "CUFFT_NO_WORKSPACE";
49     case CUFFT_NOT_IMPLEMENTED:
50       return "CUFFT_NOT_IMPLEMENTED";
51 #if !defined(USE_ROCM)
52     case CUFFT_LICENSE_ERROR:
53       return "CUFFT_LICENSE_ERROR";
54 #endif
55     case CUFFT_NOT_SUPPORTED:
56       return "CUFFT_NOT_SUPPORTED";
57     default:
58       std::ostringstream ss;
59       ss << "unknown error " << error;
60       return ss.str();
61   }
62 }
63 
CUFFT_CHECK(cufftResult error)64 static inline void CUFFT_CHECK(cufftResult error)
65 {
66   if (error != CUFFT_SUCCESS) {
67     std::ostringstream ss;
68     ss << "cuFFT error: " << _cudaGetErrorEnum(error);
69     AT_ERROR(ss.str());
70   }
71 }
72 
73 }} // at::native
74