// aten/src/ATen/native/cuda/UnaryFractionKernels.cu
#define TORCH_ASSERT_NO_OPERATORS
#include <limits>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Math.cuh>

namespace at::native {

// We manually overload ceil because std::ceil does not work with std::complex types.
template <typename scalar_t>
__host__ __device__ static inline scalar_t ceil_wrapper(scalar_t a) {
  return std::ceil(a);
}

template<typename T>
__host__ __device__ static inline std::complex<T> ceil_wrapper(std::complex<T> v) {
  return std::complex<T>(std::ceil(v.real()), std::ceil(v.imag()));
}

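// Each kernel below uses an AT_DISPATCH_* macro to instantiate its lambda once per
// supported dtype (scalar_t is defined in the scope where the lambda body is expanded);
// gpu_kernel then launches an elementwise CUDA kernel over the TensorIterator using the
// given GPU_LAMBDA.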
void ceil_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16,
      iter.dtype(), "ceil_cuda",
      [&]() {
        gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
          return ceil_wrapper(a);
        });
      });
}

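// frac computes the fractional part as a - trunc(a), so the result keeps the sign
// of the input (e.g. frac(-2.5) == -0.5).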
void frac_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16,
      iter.dtype(), "frac_cuda",
      [&]() {
        gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
          return a - ::trunc(a);
        });
      });
}

// We manually overload floor because std::floor does not work with std::complex types.
template <typename scalar_t>
__host__ __device__ static inline scalar_t floor_wrapper(scalar_t a) {
  return std::floor(a);
}

template<typename T>
__host__ __device__ static inline std::complex<T> floor_wrapper(std::complex<T> v) {
  return std::complex<T>(std::floor(v.real()), std::floor(v.imag()));
}

void floor_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16,
      iter.dtype(), "floor_cuda",
      [&]() {
        gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
          return floor_wrapper(a);
        });
      });
}

template <typename scalar_t>
__host__ __device__ static inline scalar_t reciprocal_wrapper(scalar_t a) {
  return static_cast<scalar_t>(1)/a;
}

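// For complex inputs, the extreme cases (Inf/NaN components) are special-cased below to
// match NumPy: any NaN component or a doubly-infinite input maps to {nan, nan}, while a
// single infinite component maps to {0, 0}.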
73 template<typename T>
reciprocal_wrapper(c10::complex<T> v)74 __host__ __device__ static inline c10::complex<T> reciprocal_wrapper(c10::complex<T> v) {
75   // Handle extreme cases for numpy compatibility
76   auto both_inf = [](T real, T imag) {
77     return (::isinf(real) && ::isinf(imag));
78   };
79 
80   auto either_inf = [](T real, T imag) {
81     return ::isinf(real) || ::isinf(imag);
82   };
83 
84   auto either_nan = [](T real, T imag) {
85     return ::isnan(real) || ::isnan(imag);
86   };
87 
88   if (either_nan(v.real(), v.imag()) || both_inf(v.real(), v.imag())) {
89     // If either is Nan or both are infinite, return {nan, nan}
90     return {std::numeric_limits<T>::quiet_NaN(), std::numeric_limits<T>::quiet_NaN()};
91   } else if (either_inf(v.real(), v.imag())) {
92     // If either is Inf, return {0, 0}
93     return {0, 0};
94   }
95   const c10::complex<T> one = c10::complex<T>(1.0, 0);
96   return one/v;
97 }
98 
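// Unlike the other kernels in this file, reciprocal also dispatches over complex dtypes
// (AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2) and keys the dispatch off
// iter.common_dtype() rather than iter.dtype().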
void reciprocal_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16,
      iter.common_dtype(), "reciprocal_cuda",
      [&]() {
        gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
          return reciprocal_wrapper(a);
        });
      });
}

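// nearbyint rounds halfway cases to the nearest even integer under the default
// round-to-nearest rounding mode used in device code; round_kernel_cuda relies on this
// to get banker's rounding.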
// We manually overload nearbyint because std::nearbyint does not work with std::complex types or on ROCm.
template <typename scalar_t>
__host__ __device__ static inline scalar_t nearbyint_wrapper(scalar_t a) {
  return static_cast<scalar_t>(::nearbyintf(static_cast<float>(a)));
}

__host__ __device__ static inline double nearbyint_wrapper(double a) {
  return ::nearbyint(a);
}

#pragma push
#pragma nv_diag_suppress 177   // Function was declared but never referenced
__host__ __device__ static inline c10::complex<float> nearbyint_wrapper(c10::complex<float> a) {
  return c10::complex<float>(::nearbyintf(static_cast<float>(a.real())), ::nearbyintf(static_cast<float>(a.imag())));
}

__host__ __device__ static inline c10::complex<double> nearbyint_wrapper(c10::complex<double> a) {
  return c10::complex<double>(::nearbyint(static_cast<double>(a.real())), ::nearbyint(static_cast<double>(a.imag())));
}
#pragma pop

void round_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16,
      iter.dtype(), "round_cuda",
      [&]() {
        gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
          // We do not use std::round because we would like to round midway numbers to the nearest even integer.
          return nearbyint_wrapper(a);
        });
      });
}

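// round_decimals scales by 10^|decimals| before rounding: for positive decimals the
// input is multiplied, rounded, then divided back; for negative decimals the scaling is
// inverted so that the scale factor stays an exact integer power of ten.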
void round_decimals_kernel_cuda(TensorIteratorBase& iter, int64_t decimals) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16,
      iter.dtype(), "round_cuda",
      [&]() {
        bool neg_flag = false;
        scalar_t ten_pow_decimals;
        if (decimals < 0) {
          decimals = -decimals;
          neg_flag = true;
        }
        ten_pow_decimals = static_cast<scalar_t>(std::pow(10, decimals));
        gpu_kernel(iter, [ten_pow_decimals, neg_flag]GPU_LAMBDA(scalar_t a) -> scalar_t {
          return neg_flag ? std::nearbyint(a / ten_pow_decimals) * ten_pow_decimals
                          : std::nearbyint(a * ten_pow_decimals) / ten_pow_decimals;
        });
      });
}

// We manually overload trunc because std::trunc does not work with std::complex types or on ROCm.
template <typename scalar_t>
__host__ __device__ static inline scalar_t trunc_wrapper(scalar_t a) {
  return static_cast<scalar_t>(::truncf(static_cast<float>(a)));
}

__host__ __device__ static inline double trunc_wrapper(double a) {
  return ::trunc(a);
}

__host__ __device__ static inline c10::complex<float> trunc_wrapper(c10::complex<float> a) {
  return c10::complex<float>(::truncf(static_cast<float>(a.real())), ::truncf(static_cast<float>(a.imag())));
}

__host__ __device__ static inline c10::complex<double> trunc_wrapper(c10::complex<double> a) {
  return c10::complex<double>(::trunc(static_cast<double>(a.real())), ::trunc(static_cast<double>(a.imag())));
}

void trunc_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16,
      iter.dtype(), "trunc_cuda",
      [&]() {
        gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
          return trunc_wrapper(a);
        });
      });
}

REGISTER_DISPATCH(ceil_stub, &ceil_kernel_cuda);
REGISTER_DISPATCH(frac_stub, &frac_kernel_cuda);
REGISTER_DISPATCH(floor_stub, &floor_kernel_cuda);
REGISTER_DISPATCH(reciprocal_stub, &reciprocal_kernel_cuda);
REGISTER_DISPATCH(round_stub, &round_kernel_cuda);
REGISTER_DISPATCH(round_decimals_stub, &round_decimals_kernel_cuda);
REGISTER_DISPATCH(trunc_stub, &trunc_kernel_cuda);
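// The REGISTER_DISPATCH calls above bind these CUDA kernels to the dispatch stubs
// declared in ATen/native/UnaryOps.h, so a unary op such as at::ceil on a CUDA tensor is
// routed through ceil_stub to ceil_kernel_cuda.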

} // namespace at::native