1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <limits>
3 #include <ATen/native/UnaryOps.h>
4 #include <ATen/native/cuda/Loops.cuh>
5 #include <ATen/AccumulateType.h>
6 #include <ATen/Dispatch.h>
7 #include <ATen/native/DispatchStub.h>
8 #include <ATen/native/TensorIterator.h>
9 #include <ATen/native/cuda/Math.cuh>
10
11 namespace at::native {
12
13 // We manually overload ceil because std::ceil does not work with std::complex types.
// Real-valued ceil: rounds `a` up to the nearest integral value.
// std::ceil may return a wider type for reduced-precision inputs; the result
// is copy-initialized back into scalar_t, exactly as an implicit return would.
template <typename scalar_t>
__host__ __device__ static inline scalar_t ceil_wrapper(scalar_t a) {
  scalar_t rounded_up = std::ceil(a);
  return rounded_up;
}

// Complex overload: std::ceil has no std::complex overload, so the real and
// imaginary components are rounded up independently.
template<typename T>
__host__ __device__ static inline std::complex<T> ceil_wrapper(std::complex<T> v) {
  const T re = std::ceil(v.real());
  const T im = std::ceil(v.imag());
  return {re, im};
}
23
// Elementwise ceil over `iter`, dispatched on iter.dtype() for the standard
// floating types plus Half and BFloat16.
void ceil_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "ceil_cuda", [&] {
        auto ceil_op = []GPU_LAMBDA(scalar_t x) -> scalar_t {
          return ceil_wrapper(x);
        };
        gpu_kernel(iter, ceil_op);
      });
}
34
// Elementwise fractional part, frac(x) = x - trunc(x), over `iter` for the
// standard floating types plus Half and BFloat16. The result keeps the sign
// of x because trunc rounds toward zero.
void frac_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "frac_cuda", [&] {
        auto frac_op = []GPU_LAMBDA(scalar_t x) -> scalar_t {
          // `auto` keeps whatever type ::trunc returns, so the subtraction is
          // evaluated exactly as `x - ::trunc(x)`.
          const auto whole = ::trunc(x);
          return x - whole;
        };
        gpu_kernel(iter, frac_op);
      });
}
45
46 // We manually overload floor because std::floor does not work with std::complex types.
// Real-valued floor: rounds `a` down to the nearest integral value.
// std::floor may return a wider type for reduced-precision inputs; the result
// is copy-initialized back into scalar_t, exactly as an implicit return would.
template <typename scalar_t>
__host__ __device__ static inline scalar_t floor_wrapper(scalar_t a) {
  scalar_t rounded_down = std::floor(a);
  return rounded_down;
}

// Complex overload: std::floor has no std::complex overload, so the real and
// imaginary components are rounded down independently.
template<typename T>
__host__ __device__ static inline std::complex<T> floor_wrapper(std::complex<T> v) {
  const T re = std::floor(v.real());
  const T im = std::floor(v.imag());
  return {re, im};
}
56
// Elementwise floor over `iter`, dispatched on iter.dtype() for the standard
// floating types plus Half and BFloat16.
void floor_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "floor_cuda", [&] {
        auto floor_op = []GPU_LAMBDA(scalar_t x) -> scalar_t {
          return floor_wrapper(x);
        };
        gpu_kernel(iter, floor_op);
      });
}
67
// Real reciprocal 1/a, evaluated in scalar_t arithmetic. For IEEE floating
// types, a zero input yields an infinity rather than trapping.
template <typename scalar_t>
__host__ __device__ static inline scalar_t reciprocal_wrapper(scalar_t a) {
  const scalar_t one = static_cast<scalar_t>(1);
  return one / a;
}
72
// Complex reciprocal 1/v with special-casing of non-finite inputs for numpy
// compatibility:
//   * any NaN component, or both components infinite -> {NaN, NaN}
//   * exactly one component infinite                  -> {0, 0}
//   * otherwise                                       -> (1 + 0i) / v
template<typename T>
__host__ __device__ static inline c10::complex<T> reciprocal_wrapper(c10::complex<T> v) {
  const T re = v.real();
  const T im = v.imag();
  const bool re_is_inf = ::isinf(re);
  const bool im_is_inf = ::isinf(im);

  if (::isnan(re) || ::isnan(im) || (re_is_inf && im_is_inf)) {
    const T nan = std::numeric_limits<T>::quiet_NaN();
    return {nan, nan};
  }
  if (re_is_inf || im_is_inf) {
    return {0, 0};
  }
  return c10::complex<T>(1.0, 0) / v;
}
98
// Elementwise reciprocal over `iter` for floating, complex, Half and BFloat16
// types. Unlike its siblings this dispatches on iter.common_dtype() rather
// than iter.dtype(). NOTE(review): presumably this is so type promotion decided
// when the iterator was built picks the compute type; confirm.
void reciprocal_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "reciprocal_cuda", [&] {
        auto reciprocal_op = []GPU_LAMBDA(scalar_t x) -> scalar_t {
          return reciprocal_wrapper(x);
        };
        gpu_kernel(iter, reciprocal_op);
      });
}
109
110 // We manually overload nearbyint because std::nearbyint does not work with std::complex types and ROCm.
// Round to the nearest integral value using the current rounding mode
// (ties-to-even by default). Non-double types are rounded via the float
// routine; Half and BFloat16 widen to float without loss.
template <typename scalar_t>
__host__ __device__ static inline scalar_t nearbyint_wrapper(scalar_t a) {
  const float as_float = static_cast<float>(a);
  return static_cast<scalar_t>(::nearbyintf(as_float));
}

// double needs the double-precision routine so it is not narrowed to float.
__host__ __device__ static inline double nearbyint_wrapper(double a) {
  const double rounded = ::nearbyint(a);
  return rounded;
}
119
#pragma push
#pragma nv_diag_suppress 177 // Function was declared but never referenced
// Complex overloads: each component is rounded independently with the
// precision-matching routine. round_kernel_cuda only dispatches real floating
// types, which is presumably why these may go unreferenced (hence the
// suppressed diagnostic above).
__host__ __device__ static inline c10::complex<float> nearbyint_wrapper(c10::complex<float> a) {
  const float re = ::nearbyintf(a.real());
  const float im = ::nearbyintf(a.imag());
  return c10::complex<float>(re, im);
}

__host__ __device__ static inline c10::complex<double> nearbyint_wrapper(c10::complex<double> a) {
  const double re = ::nearbyint(a.real());
  const double im = ::nearbyint(a.imag());
  return c10::complex<double>(re, im);
}
#pragma pop
130
// Elementwise round over `iter` for the standard floating types plus Half and
// BFloat16. nearbyint is used instead of std::round so that midway values go
// to the nearest even integer; std::round would round them away from zero.
void round_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "round_cuda", [&] {
        auto round_op = []GPU_LAMBDA(scalar_t x) -> scalar_t {
          return nearbyint_wrapper(x);
        };
        gpu_kernel(iter, round_op);
      });
}
142
// Elementwise round-to-`decimals`-places over `iter` for the standard floating
// types plus Half and BFloat16. Ties go to even (nearbyint under the default
// rounding mode). A negative `decimals` rounds left of the decimal point, e.g.
// decimals = -2 rounds to the nearest multiple of 100.
//
// The scale factor 10^|decimals| and the scaled intermediates are kept in the
// CUDA accumulate type (float for Half/BFloat16, scalar_t otherwise).
// Materializing the scale in Half overflows to inf for decimals >= 5
// (Half max is 65504), which turns every result into NaN (x*inf/inf).
// BFloat16 cannot represent larger powers of ten exactly, which corrupts
// the rounding. Only the final value is converted back to scalar_t.
void round_decimals_kernel_cuda(TensorIteratorBase& iter, int64_t decimals) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16,
      iter.dtype(), "round_cuda",
      [&]() {
        using acc_t = at::acc_type<scalar_t, /*is_cuda=*/true>;
        const bool neg_flag = decimals < 0;
        // Negate in floating point rather than in int64_t: -INT64_MIN overflows.
        // std::pow(10, decimals) already evaluated in double, so for every other
        // value the scale is unchanged.
        const double abs_decimals = neg_flag ? -static_cast<double>(decimals)
                                             : static_cast<double>(decimals);
        const acc_t ten_pow_decimals = static_cast<acc_t>(std::pow(10.0, abs_decimals));
        gpu_kernel(iter, [ten_pow_decimals, neg_flag]GPU_LAMBDA(scalar_t a) -> scalar_t {
          const acc_t x = static_cast<acc_t>(a);
          return neg_flag ? std::nearbyint(x / ten_pow_decimals) * ten_pow_decimals
                          : std::nearbyint(x * ten_pow_decimals) / ten_pow_decimals;
        });
      });
}
161
162 // We manually overload trunc because std::trunc does not work with std::complex types and ROCm.
// Round toward zero. Non-double types are truncated via the float routine;
// Half and BFloat16 widen to float without loss.
template <typename scalar_t>
__host__ __device__ static inline scalar_t trunc_wrapper(scalar_t a) {
  const float as_float = static_cast<float>(a);
  return static_cast<scalar_t>(::truncf(as_float));
}

// double needs the double-precision routine so it is not narrowed to float.
__host__ __device__ static inline double trunc_wrapper(double a) {
  const double truncated = ::trunc(a);
  return truncated;
}
171
// Complex overloads: truncate the real and imaginary components toward zero
// independently, using the precision-matching routine for each.
__host__ __device__ static inline c10::complex<float> trunc_wrapper(c10::complex<float> a) {
  const float re = ::truncf(a.real());
  const float im = ::truncf(a.imag());
  return c10::complex<float>(re, im);
}

__host__ __device__ static inline c10::complex<double> trunc_wrapper(c10::complex<double> a) {
  const double re = ::trunc(a.real());
  const double im = ::trunc(a.imag());
  return c10::complex<double>(re, im);
}
179
// Elementwise truncation toward zero over `iter`, dispatched on iter.dtype()
// for the standard floating types plus Half and BFloat16.
void trunc_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "trunc_cuda", [&] {
        auto trunc_op = []GPU_LAMBDA(scalar_t x) -> scalar_t {
          return trunc_wrapper(x);
        };
        gpu_kernel(iter, trunc_op);
      });
}
190
// Bind each CUDA kernel above to its dispatch stub.
// Rounding family:
REGISTER_DISPATCH(ceil_stub, &ceil_kernel_cuda);
REGISTER_DISPATCH(floor_stub, &floor_kernel_cuda);
REGISTER_DISPATCH(round_stub, &round_kernel_cuda);
REGISTER_DISPATCH(round_decimals_stub, &round_decimals_kernel_cuda);
REGISTER_DISPATCH(trunc_stub, &trunc_kernel_cuda);

// Other elementwise ops:
REGISTER_DISPATCH(frac_stub, &frac_kernel_cuda);
REGISTER_DISPATCH(reciprocal_stub, &reciprocal_kernel_cuda);
198
199 } // namespace at::native
200