#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/UnaryOps.h>

#include <limits>

#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Math.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/native/cuda/jit_utils.h>
#include <ATen/NumericUtils.h>
#include <c10/core/Scalar.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <c10/util/complex.h>

namespace at::native {

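// Most kernels in this file follow the same pattern: when the jiterator is
// enabled (AT_USE_JITERATOR()), the device code is JIT-compiled at runtime
// from a stringified implementation (the *_string variables, mostly defined
// in ATen/native/cuda/Math.cuh), with the *_name char array naming the
// generated kernel's entry point; otherwise a precompiled gpu_kernel
// fallback is used. See note [Jiterator].
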
CONSTEXPR_EXCEPT_WIN_CUDA char exp2_name[] = "exp2_kernel";
void exp2_kernel_cuda(TensorIteratorBase& iter) {
  #if AT_USE_JITERATOR()
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
        ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "exp2_cuda", [&]() {
      jitted_gpu_kernel</*name=*/exp2_name,
                        /*return_dtype=*/ scalar_t,
                        /*common_dtype=*/ scalar_t,
                        /*arity=*/ 1>(iter, exp2_string);
      });
  #else
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
        ScalarType::Half, ScalarType::BFloat16,
        iter.common_dtype(), "exp2_cuda",
        [&]() {
          gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
            return exp2_impl(a);
          });
        });
  #endif
}

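// i0/i1 below compute the modified Bessel functions of the first kind of
// orders 0 and 1; i0e/i1e are their exponentially scaled counterparts.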
CONSTEXPR_EXCEPT_WIN_CUDA char i0_name[] = "i0";
void i0_kernel_cuda(TensorIteratorBase& iter) {
  #if AT_USE_JITERATOR()
    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "i0_cuda", [&]() {
      jitted_gpu_kernel</*name=*/i0_name,
                        /*return_dtype=*/ scalar_t,
                        /*common_dtype=*/ scalar_t,
                        /*arity=*/ 1>(iter, i0_string);
      });
  #else
    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "i0_cuda", [&]() {
      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
        using opmath_t = at::opmath_type<scalar_t>;
        // `a` is implicitly converted to opmath_t here, but as far as
        // TensorIterator is concerned this is still a no-dynamic-cast kernel,
        // because the lambda's input type is scalar_t.
        return calc_i0<opmath_t>(a);
      });
    });
  #endif
}

// See note [Jiterator]
CONSTEXPR_EXCEPT_WIN_CUDA char i0e_name[] = "calc_i0e";
void i0e_kernel_cuda(TensorIteratorBase& iter) {
  #if AT_USE_JITERATOR()
    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "i0e_cuda", [&]() {
      jitted_gpu_kernel</*name=*/i0e_name,
                        /*return_dtype=*/ scalar_t,
                        /*common_dtype=*/ scalar_t,
                        /*arity=*/ 1>(iter, i0e_string);
    });
  #else
    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "i0e_cuda", [&]() {
      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
        using opmath_t = at::opmath_type<scalar_t>;
        return calc_i0e<opmath_t>(a);
      });
    });
  #endif
}

// See note [Jiterator]
CONSTEXPR_EXCEPT_WIN_CUDA char i1_name[] = "i1";
void i1_kernel_cuda(TensorIteratorBase& iter) {
  #if AT_USE_JITERATOR()
    AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "i1_cuda", [&]() {
      jitted_gpu_kernel</*name=*/i1_name,
                        /*return_dtype=*/ scalar_t,
                        /*common_dtype=*/ scalar_t,
                        /*arity=*/ 1>(iter, i1_string);
    });
  #else
    AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "i1_cuda", [&]() {
      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
        return calc_i1(a);
      });
    });
  #endif // AT_USE_JITERATOR()
}

CONSTEXPR_EXCEPT_WIN_CUDA char i1e_name[] = "i1e";
void i1e_kernel_cuda(TensorIteratorBase& iter) {
  #if AT_USE_JITERATOR()
    AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "i1e_cuda", [&]() {
      jitted_gpu_kernel</*name=*/i1e_name,
                        /*return_dtype=*/ scalar_t,
                        /*common_dtype=*/ scalar_t,
                        /*arity=*/ 1>(iter, i1e_string);
    });
  #else
    AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "i1e_cuda", [&]() {
      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
        return calc_i1e(a);
      });
    });
  #endif
}

CONSTEXPR_EXCEPT_WIN_CUDA char sigmoid_name[] = "sigmoid";
void sigmoid_kernel_cuda(TensorIteratorBase& iter) {
  auto common_dtype = iter.common_dtype();
  if (at::isComplexType(common_dtype)) {
    // Only the complex dtypes are jiterated; the real-valued kernels below
    // are always precompiled.
    #if AT_USE_JITERATOR()
      static const auto sigmoid_string = jiterator_stringify(
        template <typename T>
        T sigmoid(T x) {
          return T{1} / (T{1} + std::exp(-x));
        }
      ); // sigmoid_string
      AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "sigmoid_cuda", [&]() {
        jitted_gpu_kernel<
            /*name=*/sigmoid_name,
            /*return_dtype=*/scalar_t,
            /*common_dtype=*/scalar_t,
            /*arity=*/1>(iter, sigmoid_string);
      });
    #else
      AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "sigmoid_cuda", [&]() {
        gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
          using opmath_t = at::opmath_type<scalar_t>;
          const auto one = opmath_t{1};
          return static_cast<scalar_t>(one / (one + std::exp(-opmath_t{a})));
        });
      });
    #endif
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, common_dtype, "sigmoid_cuda", [&]() {
      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
        using opmath_t = at::opmath_type<scalar_t>;
        const auto one = opmath_t{1};
        return static_cast<scalar_t>(one / (one + std::exp(-opmath_t{a})));
      });
    });
  }
}

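// sinc(x) = sin(pi * x) / (pi * x), with sinc(0) = 1.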
CONSTEXPR_EXCEPT_WIN_CUDA char sinc_name[] = "sinc";
void sinc_kernel_cuda(TensorIteratorBase& iter) {
  #if AT_USE_JITERATOR()
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16,
      iter.common_dtype(), "sinc_cuda",
      [&]() {
        jitted_gpu_kernel</*name=*/sinc_name,
                          /*return_dtype=*/ scalar_t,
                          /*common_dtype=*/ scalar_t,
                          /*arity=*/ 1>(iter, sinc_string);
      });
  #else
    AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
        ScalarType::Half, ScalarType::BFloat16,
        iter.common_dtype(), "sinc_cuda",
        [&]() {
          gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
            if (a == scalar_t(0)) {
              return scalar_t(1);
            } else {
              // NVCC says constexpr var is not accessible from device
              using opmath_t = at::opmath_type<scalar_t>;
              opmath_t product = c10::detail::pi<opmath_t>() * opmath_t{a};
              return static_cast<scalar_t>(std::sin(product) / product);
            }
          });
        });
  #endif
}

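// logit(x) = log(x / (1 - x)). A negative eps disables clamping; otherwise x
// is clamped to [eps, 1 - eps] before the logit is computed.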
void logit_kernel_cuda(TensorIteratorBase& iter, const Scalar& eps_scalar) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.common_dtype(),
      "logit_cuda",
      [&]() {
        using T_ACC = acc_type<scalar_t, true>;
        const T_ACC eps = eps_scalar.to<T_ACC>();
        if (eps < T_ACC(0)) {
          gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
            const T_ACC x_acc = static_cast<T_ACC>(x);
            return c10::cuda::compat::log(x_acc / (T_ACC(1) - x_acc));
          });
        } else {
          const T_ACC lo = eps;
          const T_ACC hi = T_ACC(1) - eps;
          gpu_kernel(
              iter, [lo, hi] GPU_LAMBDA(scalar_t x) -> scalar_t {
                const T_ACC x_acc = static_cast<T_ACC>(x);
                T_ACC z = x_acc < lo ? lo : (x_acc > hi ? hi : x_acc);
                return c10::cuda::compat::log(z / (T_ACC(1) - z));
              });
        }
      });
}

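// ndtri(p): inverse of the CDF of the standard normal distribution.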
CONSTEXPR_EXCEPT_WIN_CUDA char ndtri_name[] = "ndtri";
void ndtri_kernel_cuda(TensorIteratorBase& iter) {
  #if AT_USE_JITERATOR()
    AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "ndtri_cuda", [&]() {
      jitted_gpu_kernel</*name=*/ndtri_name,
                        /*return_dtype=*/ scalar_t,
                        /*common_dtype=*/ scalar_t,
                        /*arity=*/ 1>(iter, ndtri_string);
    });
  #else
    AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "ndtri_cuda", [&]() {
      gpu_kernel(
          iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { return calc_ndtri(a); });
    });
  #endif
}

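// log_ndtr(x): natural log of the CDF of the standard normal distribution.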
CONSTEXPR_EXCEPT_WIN_CUDA char log_ndtr_name[] = "log_ndtr";
void log_ndtr_kernel_cuda(TensorIteratorBase& iter) {
  #if AT_USE_JITERATOR()
    AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "log_ndtr_cuda", [&]() {
      jitted_gpu_kernel</*name=*/log_ndtr_name,
                        /*return_dtype=*/ scalar_t,
                        /*common_dtype=*/ scalar_t,
                        /*arity=*/ 1>(iter, log_ndtr_string);
    });
  #else
    AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "log_ndtr_cuda", [&]() {
      gpu_kernel(
          iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { return calc_log_ndtr(a); });
    });
  #endif
}

void erf_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.common_dtype(), "erf_cuda", [&]() {
    gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
      return ::erf(a);
    });
  });
}

CONSTEXPR_EXCEPT_WIN_CUDA char erfc_name[] = "erfc_kernel";
void erfc_kernel_cuda(TensorIteratorBase& iter) {
  #if AT_USE_JITERATOR()
    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "erfc_cuda", [&]() {
      jitted_gpu_kernel</*name=*/erfc_name,
                        /*return_dtype=*/ scalar_t,
                        /*common_dtype=*/ scalar_t,
                        /*arity=*/ 1>(iter, erfc_string);
      });
  #else
    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16,
        iter.common_dtype(), "erfc_cuda", [&]() {
          gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
            return ::erfc(a);
          });
        });
  #endif
}

CONSTEXPR_EXCEPT_WIN_CUDA char erfinv_name[] = "erfinv_kernel";
void erfinv_kernel_cuda(TensorIteratorBase& iter) {
  #if AT_USE_JITERATOR()
    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "erfinv_cuda", [&]() {
      jitted_gpu_kernel</*name=*/erfinv_name,
                        /*return_dtype=*/ scalar_t,
                        /*common_dtype=*/ scalar_t,
                        /*arity=*/ 1>(iter, erfinv_string);
      });
  #else
    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16,
        iter.common_dtype(), "erfinv_cuda", [&]() {
          gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
            return ::erfinv(a);
          });
        });
  #endif
}

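// erfcx(x) = exp(x^2) * erfc(x), the scaled complementary error function.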
CONSTEXPR_EXCEPT_WIN_CUDA char erfcx_name[] = "erfcx";
void erfcx_kernel_cuda(TensorIteratorBase& iter) {
  #if AT_USE_JITERATOR()
    AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "erfcx_cuda", [&]() {
      jitted_gpu_kernel</*name=*/erfcx_name,
                        /*return_dtype=*/ scalar_t,
                        /*common_dtype=*/ scalar_t,
                        /*arity=*/ 1>(iter, erfcx_string);
    });
  #else
    AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "erfcx_cuda", [&]() {
      gpu_kernel(
          iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { return calc_erfcx(a); });
    });
  #endif
}

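// Kaiser window: out[n] = I0(beta * sqrt(1 - ((2 * n / (window_length - 1)) - 1)^2)) / I0(beta),
// where I0 is the zeroth-order modified Bessel function of the first kind.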
CONSTEXPR_EXCEPT_WIN_CUDA char kaiser_window_name[] = "kaiser_window";
void kaiser_window_kernel_cuda(TensorIteratorBase& iter, int64_t window_length, double beta_) {
  #if AT_USE_JITERATOR()
    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "kaiser_window_cuda", [&](){
        using opmath_t = at::opmath_type<scalar_t>;
        const opmath_t inv_alpha = static_cast<opmath_t>(2.0 / (window_length - 1));
        const opmath_t beta = static_cast<opmath_t>(beta_);
        const opmath_t inv_i0_beta = 1.0 / calc_i0(beta);
        jitted_gpu_kernel<
            /*name=*/kaiser_window_name,
            /*return_dtype=*/scalar_t,
            /*common_dtype=*/scalar_t,
            /*arity=*/1>(
            iter,
            kaiser_window_string,
            /*scalar_pos=*/at::cuda::jit::BinaryFuncVariant::NoScalar,
            /*scalar_val=*/0,
            /*extra_args=*/std::make_tuple(inv_alpha, beta, inv_i0_beta));
    });
  #else
    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.dtype(), "kaiser_window_cuda", [&](){
      using opmath_t = at::opmath_type<scalar_t>;
      const opmath_t inv_alpha = static_cast<opmath_t>(2.0 / (window_length - 1));
      const opmath_t beta = static_cast<opmath_t>(beta_);
      const opmath_t inv_i0_beta = 1.0 / calc_i0(beta);
      gpu_kernel(iter, [=]GPU_LAMBDA(scalar_t a) -> scalar_t {
        opmath_t x = static_cast<opmath_t>(a) * inv_alpha - 1;
        opmath_t y = std::max<opmath_t>(0, 1 - x * x);
        return calc_i0(beta * ::sqrt(y)) * inv_i0_beta;
      });
    });
  #endif
}

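// entr(x) = -x * log(x) for x > 0, 0 at x == 0, and -inf for x < 0; NaN inputs propagate.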
CONSTEXPR_EXCEPT_WIN_CUDA char entr_name[] = "entr";
void entr_kernel_cuda(TensorIteratorBase& iter) {
  #if AT_USE_JITERATOR()
    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, iter.common_dtype(), "entr_cuda", [&]() {
      jitted_gpu_kernel</*name=*/entr_name,
                        /*return_dtype=*/ scalar_t,
                        /*common_dtype=*/ scalar_t,
                        /*arity=*/ 1>(iter, entr_string);
      });
  #else
    AT_DISPATCH_FLOATING_TYPES_AND2(
        ScalarType::Half,
        ScalarType::BFloat16,
        iter.common_dtype(),
        "entr_cuda",
        [&]() {
          gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t x) -> scalar_t {
            if (at::_isnan(x)) {
              return x;
            } else if (x > 0) {
              return -x * std::log(x);
            } else if (x == 0) {
              return 0;
            }
            return static_cast<scalar_t>(-INFINITY);
          });
        });
  #endif
}

REGISTER_DISPATCH(exp2_stub, &exp2_kernel_cuda);
REGISTER_DISPATCH(i0_stub, &i0_kernel_cuda);
REGISTER_DISPATCH(special_i0e_stub, &i0e_kernel_cuda);
REGISTER_DISPATCH(special_i1_stub, &i1_kernel_cuda);
REGISTER_DISPATCH(special_i1e_stub, &i1e_kernel_cuda);
REGISTER_DISPATCH(sigmoid_stub, &sigmoid_kernel_cuda);
REGISTER_DISPATCH(sinc_stub, &sinc_kernel_cuda);
REGISTER_DISPATCH(logit_stub, &logit_kernel_cuda);
REGISTER_DISPATCH(erf_stub, &erf_kernel_cuda);
REGISTER_DISPATCH(erfc_stub, &erfc_kernel_cuda);
REGISTER_DISPATCH(erfinv_stub, &erfinv_kernel_cuda);
REGISTER_DISPATCH(kaiser_window_stub, &kaiser_window_kernel_cuda);
REGISTER_DISPATCH(special_entr_stub, &entr_kernel_cuda);
REGISTER_DISPATCH(special_ndtri_stub, &ndtri_kernel_cuda);
REGISTER_DISPATCH(special_log_ndtr_stub, &log_ndtr_kernel_cuda);
REGISTER_DISPATCH(special_erfcx_stub, &erfcx_kernel_cuda);

} // namespace at::native