xref: /aosp_15_r20/external/pytorch/aten/src/ATen/test/cuda_complex_math_test.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <gtest/gtest.h>
2 #include <c10/cuda/CUDAException.h>
3 
// Returns the number of CUDA devices reported by the runtime, or 0 when the
// devices cannot be enumerated.
//
// Any failure of cudaGetDeviceCount() (missing/insufficient driver, no device,
// or any other error status) is reported as "0 devices" so callers can use the
// result as a simple "is there a GPU to run on" check. `count` is initialized
// up front so the function never returns an indeterminate value.
int safeDeviceCount() {
  int count = 0;
  const cudaError_t err = cudaGetDeviceCount(&count);
  if (err != cudaSuccess) {
    return 0;
  }
  return count;
}
12 
// Leaves the enclosing gtest body early when safeDeviceCount() reports no
// CUDA device.
//
// The test is marked as SKIPPED (via GTEST_SKIP) rather than returning
// silently, so a run on a GPU-less machine does not show the device tests as
// passed. Must be expanded inside a function returning void (e.g. a TEST body).
#define SKIP_IF_NO_GPU()                            \
  do {                                              \
    if (safeDeviceCount() == 0) {                   \
      GTEST_SKIP() << "no CUDA device available";   \
    }                                               \
  } while (0)
19 
// Device-side variant of the closeness check: asserts |a - b| < tol using
// assert(), which is usable inside kernels.
// Every argument is parenthesized so that expression arguments such as
// `y + z` or `k * eps` keep their intended grouping after expansion.
#define C10_ASSERT_NEAR(a, b, tol) assert(abs((a) - (b)) < (tol))
// Device-side variant of C10_DEFINE_TEST(a, b).
//
// Expands to:
//   1. a forward declaration of the kernel CUDA<a><b>,
//   2. a gtest case <a>Device.<b> that is skipped when no GPU is present,
//      launches the kernel with a single block of a single thread, and fails
//      if the launch or the kernel execution reports a CUDA error,
//   3. the kernel's signature, so the `{ ... }` body written after the macro
//      invocation becomes the kernel body.
//
// The result of each cudaDeviceSynchronize() is asserted directly: errors
// raised while the kernel runs (including a failed device-side assert) are
// reported by the synchronizing call, so its status must not be discarded.
#define C10_DEFINE_TEST(a, b)                               \
__global__ void CUDA##a##b();                               \
TEST(a##Device, b) {                                        \
  SKIP_IF_NO_GPU();                                         \
  ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess);          \
  CUDA##a##b<<<1, 1>>>();                                   \
  C10_CUDA_KERNEL_LAUNCH_CHECK();                           \
  ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess);          \
  ASSERT_EQ(cudaGetLastError(), cudaSuccess);               \
}                                                           \
__global__ void CUDA##a##b()
32 #include <c10/test/util/complex_math_test_common.h>
33 
34 
// Host pass: drop the device-side definitions and rebind both macros to plain
// gtest constructs. The shared test header is included again right after this,
// this time under the <a>Host suite name and with ASSERT_NEAR as the
// closeness check.
#undef C10_ASSERT_NEAR
#undef C10_DEFINE_TEST
#define C10_ASSERT_NEAR(a, b, tol) ASSERT_NEAR(a, b, tol)
#define C10_DEFINE_TEST(a, b) TEST(a##Host, b)
39 #include <c10/test/util/complex_math_test_common.h>
40