1*da0073e9SAndroid Build Coastguard Worker #pragma once
2*da0073e9SAndroid Build Coastguard Worker
/* This file defines math functions that are compatible across different GPU
 * platforms (currently CUDA and HIP).
 */
6*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDACC__) || defined(__HIPCC__)
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Worker #include <c10/macros/Macros.h>
9*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h>
10*da0073e9SAndroid Build Coastguard Worker
// __MATH_FUNCTIONS_DECL__ is the declaration prefix applied to every function
// in c10::cuda::compat below. Branches are checked in this order:
//   - HIP (__HIPCC__):            `inline`, device-only (C10_DEVICE).
//   - NVRTC (__CUDACC_RTC__):     host+device (C10_HOST_DEVICE), no `inline`.
//   - any other CUDA compilation: `inline`, host+device (C10_HOST_DEVICE).
#if defined(__HIPCC__)
#define __MATH_FUNCTIONS_DECL__ inline C10_DEVICE
#elif defined(__CUDACC_RTC__)
#define __MATH_FUNCTIONS_DECL__ C10_HOST_DEVICE
#else
#define __MATH_FUNCTIONS_DECL__ inline C10_HOST_DEVICE
#endif
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Worker namespace c10::cuda::compat {
22*da0073e9SAndroid Build Coastguard Worker
abs(float x)23*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ float abs(float x) {
24*da0073e9SAndroid Build Coastguard Worker return ::fabsf(x);
25*da0073e9SAndroid Build Coastguard Worker }
abs(double x)26*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ double abs(double x) {
27*da0073e9SAndroid Build Coastguard Worker return ::fabs(x);
28*da0073e9SAndroid Build Coastguard Worker }
29*da0073e9SAndroid Build Coastguard Worker
exp(float x)30*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ float exp(float x) {
31*da0073e9SAndroid Build Coastguard Worker return ::expf(x);
32*da0073e9SAndroid Build Coastguard Worker }
exp(double x)33*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ double exp(double x) {
34*da0073e9SAndroid Build Coastguard Worker return ::exp(x);
35*da0073e9SAndroid Build Coastguard Worker }
36*da0073e9SAndroid Build Coastguard Worker
ceil(float x)37*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ float ceil(float x) {
38*da0073e9SAndroid Build Coastguard Worker return ::ceilf(x);
39*da0073e9SAndroid Build Coastguard Worker }
ceil(double x)40*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ double ceil(double x) {
41*da0073e9SAndroid Build Coastguard Worker return ::ceil(x);
42*da0073e9SAndroid Build Coastguard Worker }
43*da0073e9SAndroid Build Coastguard Worker
copysign(float x,float y)44*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ float copysign(float x, float y) {
45*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDA_ARCH__) || defined(__HIPCC__)
46*da0073e9SAndroid Build Coastguard Worker return ::copysignf(x, y);
47*da0073e9SAndroid Build Coastguard Worker #else
48*da0073e9SAndroid Build Coastguard Worker // std::copysign gets ICE/Segfaults with gcc 7.5/8 on arm64
49*da0073e9SAndroid Build Coastguard Worker // (e.g. Jetson), see PyTorch PR #51834
50*da0073e9SAndroid Build Coastguard Worker // This host function needs to be here for the compiler but is never used
51*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(
52*da0073e9SAndroid Build Coastguard Worker false, "CUDAMathCompat copysign should not run on the CPU");
53*da0073e9SAndroid Build Coastguard Worker #endif
54*da0073e9SAndroid Build Coastguard Worker }
copysign(double x,double y)55*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ double copysign(double x, double y) {
56*da0073e9SAndroid Build Coastguard Worker #if defined(__CUDA_ARCH__) || defined(__HIPCC__)
57*da0073e9SAndroid Build Coastguard Worker return ::copysign(x, y);
58*da0073e9SAndroid Build Coastguard Worker #else
59*da0073e9SAndroid Build Coastguard Worker // see above
60*da0073e9SAndroid Build Coastguard Worker TORCH_INTERNAL_ASSERT(
61*da0073e9SAndroid Build Coastguard Worker false, "CUDAMathCompat copysign should not run on the CPU");
62*da0073e9SAndroid Build Coastguard Worker #endif
63*da0073e9SAndroid Build Coastguard Worker }
64*da0073e9SAndroid Build Coastguard Worker
floor(float x)65*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ float floor(float x) {
66*da0073e9SAndroid Build Coastguard Worker return ::floorf(x);
67*da0073e9SAndroid Build Coastguard Worker }
floor(double x)68*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ double floor(double x) {
69*da0073e9SAndroid Build Coastguard Worker return ::floor(x);
70*da0073e9SAndroid Build Coastguard Worker }
71*da0073e9SAndroid Build Coastguard Worker
log(float x)72*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ float log(float x) {
73*da0073e9SAndroid Build Coastguard Worker return ::logf(x);
74*da0073e9SAndroid Build Coastguard Worker }
log(double x)75*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ double log(double x) {
76*da0073e9SAndroid Build Coastguard Worker return ::log(x);
77*da0073e9SAndroid Build Coastguard Worker }
78*da0073e9SAndroid Build Coastguard Worker
log1p(float x)79*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ float log1p(float x) {
80*da0073e9SAndroid Build Coastguard Worker return ::log1pf(x);
81*da0073e9SAndroid Build Coastguard Worker }
82*da0073e9SAndroid Build Coastguard Worker
log1p(double x)83*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ double log1p(double x) {
84*da0073e9SAndroid Build Coastguard Worker return ::log1p(x);
85*da0073e9SAndroid Build Coastguard Worker }
86*da0073e9SAndroid Build Coastguard Worker
max(float x,float y)87*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ float max(float x, float y) {
88*da0073e9SAndroid Build Coastguard Worker return ::fmaxf(x, y);
89*da0073e9SAndroid Build Coastguard Worker }
max(double x,double y)90*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ double max(double x, double y) {
91*da0073e9SAndroid Build Coastguard Worker return ::fmax(x, y);
92*da0073e9SAndroid Build Coastguard Worker }
93*da0073e9SAndroid Build Coastguard Worker
min(float x,float y)94*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ float min(float x, float y) {
95*da0073e9SAndroid Build Coastguard Worker return ::fminf(x, y);
96*da0073e9SAndroid Build Coastguard Worker }
min(double x,double y)97*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ double min(double x, double y) {
98*da0073e9SAndroid Build Coastguard Worker return ::fmin(x, y);
99*da0073e9SAndroid Build Coastguard Worker }
100*da0073e9SAndroid Build Coastguard Worker
pow(float x,float y)101*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ float pow(float x, float y) {
102*da0073e9SAndroid Build Coastguard Worker return ::powf(x, y);
103*da0073e9SAndroid Build Coastguard Worker }
pow(double x,double y)104*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ double pow(double x, double y) {
105*da0073e9SAndroid Build Coastguard Worker return ::pow(x, y);
106*da0073e9SAndroid Build Coastguard Worker }
107*da0073e9SAndroid Build Coastguard Worker
sincos(float x,float * sptr,float * cptr)108*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ void sincos(float x, float* sptr, float* cptr) {
109*da0073e9SAndroid Build Coastguard Worker return ::sincosf(x, sptr, cptr);
110*da0073e9SAndroid Build Coastguard Worker }
sincos(double x,double * sptr,double * cptr)111*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ void sincos(double x, double* sptr, double* cptr) {
112*da0073e9SAndroid Build Coastguard Worker return ::sincos(x, sptr, cptr);
113*da0073e9SAndroid Build Coastguard Worker }
114*da0073e9SAndroid Build Coastguard Worker
sqrt(float x)115*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ float sqrt(float x) {
116*da0073e9SAndroid Build Coastguard Worker return ::sqrtf(x);
117*da0073e9SAndroid Build Coastguard Worker }
sqrt(double x)118*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ double sqrt(double x) {
119*da0073e9SAndroid Build Coastguard Worker return ::sqrt(x);
120*da0073e9SAndroid Build Coastguard Worker }
121*da0073e9SAndroid Build Coastguard Worker
rsqrt(float x)122*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ float rsqrt(float x) {
123*da0073e9SAndroid Build Coastguard Worker return ::rsqrtf(x);
124*da0073e9SAndroid Build Coastguard Worker }
rsqrt(double x)125*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ double rsqrt(double x) {
126*da0073e9SAndroid Build Coastguard Worker return ::rsqrt(x);
127*da0073e9SAndroid Build Coastguard Worker }
128*da0073e9SAndroid Build Coastguard Worker
tan(float x)129*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ float tan(float x) {
130*da0073e9SAndroid Build Coastguard Worker return ::tanf(x);
131*da0073e9SAndroid Build Coastguard Worker }
tan(double x)132*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ double tan(double x) {
133*da0073e9SAndroid Build Coastguard Worker return ::tan(x);
134*da0073e9SAndroid Build Coastguard Worker }
135*da0073e9SAndroid Build Coastguard Worker
tanh(float x)136*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ float tanh(float x) {
137*da0073e9SAndroid Build Coastguard Worker return ::tanhf(x);
138*da0073e9SAndroid Build Coastguard Worker }
tanh(double x)139*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ double tanh(double x) {
140*da0073e9SAndroid Build Coastguard Worker return ::tanh(x);
141*da0073e9SAndroid Build Coastguard Worker }
142*da0073e9SAndroid Build Coastguard Worker
normcdf(float x)143*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ float normcdf(float x) {
144*da0073e9SAndroid Build Coastguard Worker return ::normcdff(x);
145*da0073e9SAndroid Build Coastguard Worker }
normcdf(double x)146*da0073e9SAndroid Build Coastguard Worker __MATH_FUNCTIONS_DECL__ double normcdf(double x) {
147*da0073e9SAndroid Build Coastguard Worker return ::normcdf(x);
148*da0073e9SAndroid Build Coastguard Worker }
149*da0073e9SAndroid Build Coastguard Worker
150*da0073e9SAndroid Build Coastguard Worker } // namespace c10::cuda::compat
151*da0073e9SAndroid Build Coastguard Worker
152*da0073e9SAndroid Build Coastguard Worker #endif
153