#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/UnaryOps.h>

#include <limits>

#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Math.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/NumericUtils.h>
#include <ATen/OpMathType.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/core/Scalar.h>
#include <c10/util/complex.h>

namespace at::native {

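// bitwise_not: logical NOT for bool tensors, bitwise complement (~) for the
// integral dtypes. Bool is special-cased because applying ~ to a bool promotes
// it to int and does not yield the logical negation.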
void bitwise_not_kernel_cuda(TensorIteratorBase& iter) {
  if (iter.dtype() == ScalarType::Bool) {
    gpu_kernel(iter, []GPU_LAMBDA(bool a) {
      return !a;
    });
  } else {
    AT_DISPATCH_INTEGRAL_TYPES(iter.dtype(), "bitwise_not_cuda", [&]() {
      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
        return ~a;
      });
    });
  }
}

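// For complex dtypes, exp (and the other kernels below) prefer the jiterator
// path when AT_USE_JITERATOR() is enabled: the stringified kernel source is
// JIT-compiled at runtime and cached under the kernel name (here exp_name),
// which keeps the complex instantiations out of the statically compiled
// binary. Without the jiterator, a precompiled gpu_kernel computes the result
// in opmath precision instead.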
CONSTEXPR_EXCEPT_WIN_CUDA char exp_name[] = "exp_kernel";
void exp_kernel_cuda(TensorIteratorBase& iter) {
  auto common_dtype = iter.common_dtype();
  if (at::isComplexType(common_dtype)) {
#if AT_USE_JITERATOR()
    static const auto exp_string = jiterator_stringify(
        template <typename T>
        T exp_kernel(T x) {
          return std::exp(x);
      }); // exp_string
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "exp_cuda", [&]() {
      jitted_gpu_kernel<
          /*name=*/exp_name,
          /*return_dtype=*/scalar_t,
          /*common_dtype=*/scalar_t,
          /*arity=*/1>(iter, exp_string);
    });
#else
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "exp_cuda", [&]() {
      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
        using opmath_t = at::opmath_type<scalar_t>;
        return std::exp(static_cast<opmath_t>(a));
      });
    });
#endif
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, common_dtype, "exp_cuda", [&]() {
      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
        return std::exp(a);
      });
    });
  }
}

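// expm1(a) computes exp(a) - 1 directly, which avoids the cancellation that
// evaluating exp(a) - 1 naively would suffer for a close to zero.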
void expm1_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      ScalarType::BFloat16, ScalarType::Half,
      iter.common_dtype(), "expm1_cuda",
      [&]() {
        gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
          return ::expm1(a);
        });
      });
}

// We manually overload rsqrt because std::rsqrt does not work with complex types.
template<typename scalar_t>
C10_HOST_DEVICE static inline scalar_t rsqrt_wrapper(scalar_t v) {
  return ::rsqrt(v);
}

template<typename T>
C10_HOST_DEVICE static inline c10::complex<T> rsqrt_wrapper(c10::complex<T> v) {
  const c10::complex<T> one = c10::complex<T>(1.0, 0);
  // std::sqrt for c10::complex is overloaded in c10/util/complex_math.h
  return one / ::sqrt(v);
}

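// rsqrt(x) = 1 / sqrt(x). Complex inputs go through the jiterator string
// kernel (or the rsqrt_wrapper overload above as a fallback); real inputs use
// the CUDA ::rsqrt overloads via rsqrt_wrapper.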
CONSTEXPR_EXCEPT_WIN_CUDA char rsqrt_name[] = "rsqrt_kernel";
void rsqrt_kernel_cuda(TensorIteratorBase& iter) {
  auto common_dtype = iter.common_dtype();
  if (at::isComplexType(common_dtype)) {
#if AT_USE_JITERATOR()
    static const auto rsqrt_string = jiterator_stringify(
        template <typename T>
        T rsqrt_kernel(T x) {
          const T one = T{1};
          return one / std::sqrt(x);
      }); // rsqrt_string
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "rsqrt_cuda", [&]() {
      jitted_gpu_kernel<
          /*name=*/rsqrt_name,
          /*return_dtype=*/scalar_t,
          /*common_dtype=*/scalar_t,
          /*arity=*/1>(iter, rsqrt_string);
    });
#else
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "rsqrt_cuda", [&]() {
      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
        using opmath_t = at::opmath_type<scalar_t>;
        return rsqrt_wrapper(static_cast<opmath_t>(a));
      });
    });
#endif
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        ScalarType::BFloat16, ScalarType::Half,
        iter.common_dtype(), "rsqrt_cuda",
        [&]() {
          gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
            // In CUDA, ::rsqrt is overloaded for float and at::Half here is implicitly cast to float.
            return rsqrt_wrapper(a);
          });
        });
  }
}

CONSTEXPR_EXCEPT_WIN_CUDA char sqrt_name[] = "sqrt_kernel";
void sqrt_kernel_cuda(TensorIteratorBase& iter) {
  auto common_dtype = iter.common_dtype();
  if (at::isComplexType(common_dtype)) {
#if AT_USE_JITERATOR()
    static const auto sqrt_string = jiterator_stringify(
        template <typename T>
        T sqrt_kernel(T x) {
          return std::sqrt(x);
      }); // sqrt_string
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "sqrt_cuda", [&]() {
      jitted_gpu_kernel<
          /*name=*/sqrt_name,
          /*return_dtype=*/scalar_t,
          /*common_dtype=*/scalar_t,
          /*arity=*/1>(iter, sqrt_string);
    });
#else
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "sqrt_cuda", [&]() {
      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
        using opmath_t = at::opmath_type<scalar_t>;
        return ::sqrt(static_cast<opmath_t>(a));
      });
    });
#endif
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, common_dtype, "sqrt_cuda", [&]() {
      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
        return std::sqrt(a);
      });
    });
  }
}

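// clamp / clamp_min / clamp_max clip each element into the requested range.
// NaN inputs are returned unchanged; the explicit _isnan check is needed
// because ::min/::max do not propagate NaN automatically on ROCm.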
void clamp_kernel_cuda(TensorIteratorBase& iter, const Scalar& min_value, const Scalar& max_value) {
  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "clamp_cuda", [&]() {
    auto lower = min_value.to<scalar_t>();
    auto upper = max_value.to<scalar_t>();
    gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t v) -> scalar_t {
      // Propagate nan, which doesn't propagate automatically for ROCm
      if (_isnan(v)) {
        return v;
      } else {
        return ::min(::max(v, lower), upper);
      }
    });
  });
}

void clamp_min_kernel_cuda(TensorIteratorBase& iter, const Scalar& min_value) {
  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "clamp_min_cuda", [&]() {
    auto lower = min_value.to<scalar_t>();
    gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t v) -> scalar_t {
      // Propagate nan, which doesn't propagate automatically for ROCm
      if (_isnan(v)) {
        return v;
      } else {
        return ::max(v, lower);
      }
    });
  });
}

void clamp_max_kernel_cuda(TensorIteratorBase& iter, const Scalar& max_value) {
  AT_DISPATCH_ALL_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "clamp_max_cuda", [&]() {
    auto upper = max_value.to<scalar_t>();
    gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t v) -> scalar_t {
      // Propagate nan, which doesn't propagate automatically for ROCm
      if (_isnan(v)) {
        return v;
      } else {
        return ::min(v, upper);
      }
    });
  });
}

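// Maps NaN to nan_replacement, +inf to pos_inf_replacement, and -inf to
// neg_inf_replacement; finite values pass through unchanged.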
template<typename scalar_t>
C10_HOST_DEVICE static inline scalar_t _nan_to_num_replace(scalar_t a, scalar_t nan_replacement, scalar_t pos_inf_replacement, scalar_t neg_inf_replacement) {
  return at::_isnan(a)
    ? nan_replacement
    : (a == std::numeric_limits<scalar_t>::infinity()
      ? pos_inf_replacement
      : (a == -std::numeric_limits<scalar_t>::infinity()
        ? neg_inf_replacement
        : a));
}

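// For complex tensors the replacement is applied to the real and imaginary
// parts independently. When pos_inf/neg_inf are not provided, they default to
// the largest/lowest finite value of the (real) value type.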
void nan_to_num_kernel_cuda(
    TensorIteratorBase& iter,
    std::optional<double> nan,
    std::optional<double> pos_inf,
    std::optional<double> neg_inf) {
  if (isComplexType(iter.dtype())) {
    AT_DISPATCH_COMPLEX_TYPES(iter.dtype(), "nan_to_num", [&]() {
      using value_t = scalar_t::value_type;
      value_t nan_replacement = static_cast<value_t>(nan.value_or(0.));
      value_t pos_inf_replacement = pos_inf.has_value()
          ? static_cast<value_t>(pos_inf.value())
          : std::numeric_limits<value_t>::max();
      value_t neg_inf_replacement = neg_inf.has_value()
          ? static_cast<value_t>(neg_inf.value())
          : std::numeric_limits<value_t>::lowest();

      gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t a) -> scalar_t {
        value_t res_real = _nan_to_num_replace(
            a.real(), nan_replacement, pos_inf_replacement, neg_inf_replacement);
        value_t res_imag = _nan_to_num_replace(
            a.imag(), nan_replacement, pos_inf_replacement, neg_inf_replacement);
        return scalar_t(res_real, res_imag);
      });
    });
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, iter.dtype(), "nan_to_num_cuda", [&]() {
      scalar_t nan_replacement = static_cast<scalar_t>(nan.value_or(0.));
      scalar_t pos_inf_replacement = pos_inf.has_value()
          ? static_cast<scalar_t>(pos_inf.value())
          : std::numeric_limits<scalar_t>::max();
      scalar_t neg_inf_replacement = neg_inf.has_value()
          ? static_cast<scalar_t>(neg_inf.value())
          : std::numeric_limits<scalar_t>::lowest();

      gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t a) -> scalar_t {
        return _nan_to_num_replace(
            a, nan_replacement, pos_inf_replacement, neg_inf_replacement);
      });
    });
  }
}

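// frexp decomposes each element a into a mantissa m with |m| in [0.5, 1)
// (or 0) and an integer exponent e such that a == m * 2^e. The two results
// are written to separate output tensors via gpu_kernel_multiple_outputs.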
void frexp_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16,
    // The iter.dtype() here is the dtype of mantissa output.
    // It's a floating point type and must be the same as the input's dtype.
    iter.dtype(),
    "frexp_cuda", [&]() {
      gpu_kernel_multiple_outputs(iter, [=] GPU_LAMBDA (scalar_t a) -> thrust::tuple<scalar_t, int32_t> {
        int32_t exponent;
        scalar_t mantissa = std::frexp(a, &exponent);
        return {mantissa, exponent};
      });
  });
}

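// Register the CUDA kernels with their dispatch stubs, which are declared in
// ATen/native/UnaryOps.h (included at the top of this file).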
REGISTER_DISPATCH(bitwise_not_stub, &bitwise_not_kernel_cuda);
REGISTER_DISPATCH(exp_stub, &exp_kernel_cuda);
REGISTER_DISPATCH(expm1_stub, &expm1_kernel_cuda);
REGISTER_DISPATCH(rsqrt_stub, &rsqrt_kernel_cuda);
REGISTER_DISPATCH(sqrt_stub, &sqrt_kernel_cuda);
REGISTER_DISPATCH(nan_to_num_stub, &nan_to_num_kernel_cuda);
REGISTER_DISPATCH(frexp_stub, &frexp_kernel_cuda);

} // namespace at::native