#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/UnaryOps.h>

#include <limits>

#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Math.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/NumericUtils.h>
#include <c10/core/Scalar.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/util/complex.h>

namespace at::native {

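// Most kernels below come in two flavors, selected at compile time by
// AT_USE_JITERATOR(): a jiterator path that JIT-compiles the elementwise op
// from its stringified source (the *_string constants, see
// ATen/native/cuda/Math.cuh), and a fallback path that launches a
// precompiled gpu_kernel lambda. See note [Jiterator].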
CONSTEXPR_EXCEPT_WIN_CUDA char exp2_name[] = "exp2_kernel";
void exp2_kernel_cuda(TensorIteratorBase& iter) {
#if AT_USE_JITERATOR()
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "exp2_cuda", [&]() {
        jitted_gpu_kernel</*name=*/exp2_name,
                          /*return_dtype=*/ scalar_t,
                          /*common_dtype=*/ scalar_t,
                          /*arity=*/ 1>(iter, exp2_string);
      });
#else
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16,
      iter.common_dtype(), "exp2_cuda",
      [&]() {
        gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
          return exp2_impl(a);
        });
      });
#endif
}

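// i0/i1 compute the modified Bessel functions of the first kind of order
// zero and one; i0e/i1e are their exponentially scaled variants,
// i0e(x) = exp(-|x|) * i0(x) and i1e(x) = exp(-|x|) * i1(x).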
CONSTEXPR_EXCEPT_WIN_CUDA char i0_name[] = "i0";
void i0_kernel_cuda(TensorIteratorBase& iter) {
#if AT_USE_JITERATOR()
  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "i0_cuda", [&]() {
    jitted_gpu_kernel</*name=*/i0_name,
                      /*return_dtype=*/ scalar_t,
                      /*common_dtype=*/ scalar_t,
                      /*arity=*/ 1>(iter, i0_string);
  });
#else
  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "i0_cuda", [&]() {
    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
      using opmath_t = at::opmath_type<scalar_t>;
      // implicit conversion of a to opmath_t will happen here,
      // but as far as TI is concerned, it's still a no-dynamic-cast kernel because lambda input is scalar_t
      return calc_i0<opmath_t>(a);
    });
  });
#endif
}

// See note [Jiterator]
CONSTEXPR_EXCEPT_WIN_CUDA char i0e_name[] = "calc_i0e";
void i0e_kernel_cuda(TensorIteratorBase& iter) {
#if AT_USE_JITERATOR()
  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "i0e_cuda", [&]() {
    jitted_gpu_kernel</*name=*/i0e_name,
                      /*return_dtype=*/ scalar_t,
                      /*common_dtype=*/ scalar_t,
                      /*arity=*/ 1>(iter, i0e_string);
  });
#else
  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "i0e_cuda", [&]() {
    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
      using opmath_t = at::opmath_type<scalar_t>;
      return calc_i0e<opmath_t>(a);
    });
  });
#endif
}

// See note [Jiterator]

CONSTEXPR_EXCEPT_WIN_CUDA char i1_name[] = "i1";
void i1_kernel_cuda(TensorIteratorBase& iter) {
#if AT_USE_JITERATOR()
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "i1_cuda", [&]() {
    jitted_gpu_kernel</*name=*/i1_name,
                      /*return_dtype=*/ scalar_t,
                      /*common_dtype=*/ scalar_t,
                      /*arity=*/ 1>(iter, i1_string);
  });
#else
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "i1_cuda", [&]() {
    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
      return calc_i1(a);
    });
  });
#endif // AT_USE_JITERATOR()
}

CONSTEXPR_EXCEPT_WIN_CUDA char i1e_name[] = "i1e";
void i1e_kernel_cuda(TensorIteratorBase& iter) {
#if AT_USE_JITERATOR()
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "i1e_cuda", [&]() {
    jitted_gpu_kernel</*name=*/i1e_name,
                      /*return_dtype=*/ scalar_t,
                      /*common_dtype=*/ scalar_t,
                      /*arity=*/ 1>(iter, i1e_string);
  });
#else
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "i1e_cuda", [&]() {
    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
      return calc_i1e(a);
    });
  });
#endif
}

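// sigmoid(x) = 1 / (1 + exp(-x)). Complex inputs are jiterated when the
// jiterator is available; real inputs always use the precompiled path,
// computing in opmath_t (float for half/bfloat16) for accuracy.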
CONSTEXPR_EXCEPT_WIN_CUDA char sigmoid_name[] = "sigmoid";
void sigmoid_kernel_cuda(TensorIteratorBase& iter) {
  auto common_dtype = iter.common_dtype();
  if (at::isComplexType(common_dtype)) {
    // only jiterate for complex-dtype
#if AT_USE_JITERATOR()
    static const auto sigmoid_string = jiterator_stringify(
        template <typename T>
        T sigmoid(T x) {
          return T{1} / (T{1} + std::exp(-x));
        }
    ); // sigmoid_string
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "sigmoid_cuda", [&]() {
      jitted_gpu_kernel<
          /*name=*/sigmoid_name,
          /*return_dtype=*/scalar_t,
          /*common_dtype=*/scalar_t,
          /*arity=*/1>(iter, sigmoid_string);
    });
#else
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "sigmoid_cuda", [&]() {
      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
        using opmath_t = at::opmath_type<scalar_t>;
        const auto one = opmath_t{1};
        return static_cast<scalar_t>(one / (one + std::exp(-opmath_t{a})));
      });
    });
#endif
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, common_dtype, "sigmoid_cuda", [&]() {
      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
        using opmath_t = at::opmath_type<scalar_t>;
        const auto one = opmath_t{1};
        return static_cast<scalar_t>(one / (one + std::exp(-opmath_t{a})));
      });
    });
  }
}

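// sinc(x) = sin(pi * x) / (pi * x), with the removable singularity at
// x == 0 defined as 1.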
CONSTEXPR_EXCEPT_WIN_CUDA char sinc_name[] = "sinc";
void sinc_kernel_cuda(TensorIteratorBase& iter) {
#if AT_USE_JITERATOR()
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16,
      iter.common_dtype(), "sinc_cuda",
      [&]() {
        jitted_gpu_kernel</*name=*/sinc_name,
                          /*return_dtype=*/ scalar_t,
                          /*common_dtype=*/ scalar_t,
                          /*arity=*/ 1>(iter, sinc_string);
      });
#else
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16,
      iter.common_dtype(), "sinc_cuda",
      [&]() {
        gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
          if (a == scalar_t(0)) {
            return scalar_t(1);
          } else {
            // NVCC says constexpr var is not accessible from device
            using opmath_t = at::opmath_type<scalar_t>;
            opmath_t product = c10::detail::pi<opmath_t>() * opmath_t{a};
            return static_cast<scalar_t>(std::sin(product) / product);
          }
        });
      });
#endif
}

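// logit(x) = log(x / (1 - x)). A negative eps disables clamping; otherwise
// x is clamped to [eps, 1 - eps] before the transform.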
void logit_kernel_cuda(TensorIteratorBase& iter, const Scalar& eps_scalar) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.common_dtype(),
      "logit_cuda",
      [&]() {
        using T_ACC = acc_type<scalar_t, true>;
        const T_ACC eps = eps_scalar.to<T_ACC>();
        if (eps < T_ACC(0)) {
          gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
            const T_ACC x_acc = static_cast<T_ACC>(x);
            return c10::cuda::compat::log(x_acc / (T_ACC(1) - x_acc));
          });
        } else {
          const T_ACC lo = eps;
          const T_ACC hi = T_ACC(1) - eps;
          gpu_kernel(
              iter, [lo, hi] GPU_LAMBDA(scalar_t x) -> scalar_t {
                const T_ACC x_acc = static_cast<T_ACC>(x);
                T_ACC z = x_acc < lo ? lo : (x_acc > hi ? hi : x_acc);
                return c10::cuda::compat::log(z / (T_ACC(1) - z));
              });
        }
      });
}

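// ndtri(p) is the quantile function (inverse CDF) of the standard normal
// distribution, i.e. ndtr(ndtri(p)) == p for p in (0, 1).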
CONSTEXPR_EXCEPT_WIN_CUDA char ndtri_name[] = "ndtri";
void ndtri_kernel_cuda(TensorIteratorBase& iter) {
#if AT_USE_JITERATOR()
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "ndtri_cuda", [&]() {
    jitted_gpu_kernel</*name=*/ndtri_name,
                      /*return_dtype=*/ scalar_t,
                      /*common_dtype=*/ scalar_t,
                      /*arity=*/ 1>(iter, ndtri_string);
  });
#else
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "ndtri_cuda", [&]() {
    gpu_kernel(
        iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { return calc_ndtri(a); });
  });
#endif
}

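// log_ndtr(x) = log(ndtr(x)), the log of the standard normal CDF, which
// stays finite and accurate for large negative x where ndtr(x) underflows.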
CONSTEXPR_EXCEPT_WIN_CUDA char log_ndtr_name[] = "log_ndtr";
void log_ndtr_kernel_cuda(TensorIteratorBase& iter) {
#if AT_USE_JITERATOR()
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "log_ndtr_cuda", [&]() {
    jitted_gpu_kernel</*name=*/log_ndtr_name,
                      /*return_dtype=*/ scalar_t,
                      /*common_dtype=*/ scalar_t,
                      /*arity=*/ 1>(iter, log_ndtr_string);
  });
#else
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "log_ndtr_cuda", [&]() {
    gpu_kernel(
        iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { return calc_log_ndtr(a); });
  });
#endif
}

void erf_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "erf_cuda", [&]() {
    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
      return ::erf(a);
    });
  });
}

CONSTEXPR_EXCEPT_WIN_CUDA char erfc_name[] = "erfc_kernel";
void erfc_kernel_cuda(TensorIteratorBase& iter) {
#if AT_USE_JITERATOR()
  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "erfc_cuda", [&]() {
    jitted_gpu_kernel</*name=*/erfc_name,
                      /*return_dtype=*/ scalar_t,
                      /*common_dtype=*/ scalar_t,
                      /*arity=*/ 1>(iter, erfc_string);
  });
#else
  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16,
                                  iter.common_dtype(), "erfc_cuda", [&]() {
    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
      return ::erfc(a);
    });
  });
#endif
}

CONSTEXPR_EXCEPT_WIN_CUDA char erfinv_name[] = "erfinv_kernel";
void erfinv_kernel_cuda(TensorIteratorBase& iter) {
#if AT_USE_JITERATOR()
  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "erfinv_cuda", [&]() {
    jitted_gpu_kernel</*name=*/erfinv_name,
                      /*return_dtype=*/ scalar_t,
                      /*common_dtype=*/ scalar_t,
                      /*arity=*/ 1>(iter, erfinv_string);
  });
#else
  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16,
                                  iter.common_dtype(), "erfinv_cuda", [&]() {
    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
      return ::erfinv(a);
    });
  });
#endif
}

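// erfcx(x) = exp(x^2) * erfc(x), the scaled complementary error function,
// which avoids the underflow of erfc(x) for large positive x.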
CONSTEXPR_EXCEPT_WIN_CUDA char erfcx_name[] = "erfcx";
void erfcx_kernel_cuda(TensorIteratorBase& iter) {
#if AT_USE_JITERATOR()
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "erfcx_cuda", [&]() {
    jitted_gpu_kernel</*name=*/erfcx_name,
                      /*return_dtype=*/ scalar_t,
                      /*common_dtype=*/ scalar_t,
                      /*arity=*/ 1>(iter, erfcx_string);
  });
#else
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "erfcx_cuda", [&]() {
    gpu_kernel(
        iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { return calc_erfcx(a); });
  });
#endif
}

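// Kaiser window: w[n] = I0(beta * sqrt(1 - ((2n / (N - 1)) - 1)^2)) / I0(beta)
// for n in [0, N), where N is window_length and I0 is the modified Bessel
// function computed by calc_i0. inv_alpha folds the 2 / (N - 1) factor and
// inv_i0_beta the 1 / I0(beta) normalization into per-launch constants.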
CONSTEXPR_EXCEPT_WIN_CUDA char kaiser_window_name[] = "kaiser_window";
void kaiser_window_kernel_cuda(TensorIteratorBase& iter, int64_t window_length, double beta_){
#if AT_USE_JITERATOR()
  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "kaiser_window_cuda", [&](){
    using opmath_t = at::opmath_type<scalar_t>;
    const opmath_t inv_alpha = static_cast<opmath_t>(2.0 / (window_length - 1));
    const opmath_t beta = static_cast<opmath_t>(beta_);
    const opmath_t inv_i0_beta = 1.0 / calc_i0(beta);
    jitted_gpu_kernel<
        /*name=*/kaiser_window_name,
        /*return_dtype=*/scalar_t,
        /*common_dtype=*/scalar_t,
        /*arity=*/1>(
        iter,
        kaiser_window_string,
        /*scalar_pos=*/at::cuda::jit::BinaryFuncVariant::NoScalar,
        /*scalar_val=*/0,
        /*extra_args=*/std::make_tuple(inv_alpha, beta, inv_i0_beta));
  });
#else
  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "kaiser_window_cuda", [&](){
    using opmath_t = at::opmath_type<scalar_t>;
    const opmath_t inv_alpha = static_cast<opmath_t>(2.0 / (window_length - 1));
    const opmath_t beta = static_cast<opmath_t>(beta_);
    const opmath_t inv_i0_beta = 1.0 / calc_i0(beta);
    gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t a) -> scalar_t {
      opmath_t x = static_cast<opmath_t>(a) * inv_alpha - 1;
      opmath_t y = std::max<opmath_t>(0, 1 - x * x);
      return calc_i0(beta * ::sqrt(y)) * inv_i0_beta;
    });
  });
#endif
}

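// entr(x) = -x * ln(x) for x > 0, 0 at x == 0, -inf for x < 0; NaN inputs
// propagate unchanged.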
CONSTEXPR_EXCEPT_WIN_CUDA char entr_name[] = "entr";
void entr_kernel_cuda(TensorIteratorBase& iter) {
#if AT_USE_JITERATOR()
  AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "entr_cuda", [&]() {
    jitted_gpu_kernel</*name=*/entr_name,
                      /*return_dtype=*/ scalar_t,
                      /*common_dtype=*/ scalar_t,
                      /*arity=*/ 1>(iter, entr_string);
  });
#else
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::Half,
      ScalarType::BFloat16,
      iter.common_dtype(),
      "entr_cuda",
      [&]() {
        gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t x) -> scalar_t {
          if (at::_isnan(x)) {
            return x;
          } else if (x > 0) {
            return -x * std::log(x);
          } else if (x == 0) {
            return 0;
          }
          return static_cast<scalar_t>(-INFINITY);
        });
      });
#endif
}

REGISTER_DISPATCH(exp2_stub, &exp2_kernel_cuda);
REGISTER_DISPATCH(i0_stub, &i0_kernel_cuda);
REGISTER_DISPATCH(special_i0e_stub, &i0e_kernel_cuda);
REGISTER_DISPATCH(special_i1_stub, &i1_kernel_cuda);
REGISTER_DISPATCH(special_i1e_stub, &i1e_kernel_cuda);
REGISTER_DISPATCH(sigmoid_stub, &sigmoid_kernel_cuda);
REGISTER_DISPATCH(sinc_stub, &sinc_kernel_cuda);
REGISTER_DISPATCH(logit_stub, &logit_kernel_cuda);
REGISTER_DISPATCH(erf_stub, &erf_kernel_cuda);
REGISTER_DISPATCH(erfc_stub, &erfc_kernel_cuda);
REGISTER_DISPATCH(erfinv_stub, &erfinv_kernel_cuda);
REGISTER_DISPATCH(kaiser_window_stub, &kaiser_window_kernel_cuda);
REGISTER_DISPATCH(special_entr_stub, &entr_kernel_cuda);
REGISTER_DISPATCH(special_ndtri_stub, &ndtri_kernel_cuda);
REGISTER_DISPATCH(special_log_ndtr_stub, &log_ndtr_kernel_cuda);
REGISTER_DISPATCH(special_erfcx_stub, &erfcx_kernel_cuda);

} // namespace at::native