xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/UnaryGeometricAsinhKernel.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/AccumulateType.h>
3 #include <ATen/Dispatch.h>
4 #include <ATen/OpMathType.h>
5 #include <ATen/native/DispatchStub.h>
6 #include <ATen/native/TensorIterator.h>
7 #include <ATen/native/UnaryOps.h>
8 #include <ATen/native/cuda/JitLoops.cuh>
9 #include <ATen/native/cuda/Loops.cuh>
10 #include <ATen/native/cuda/Math.cuh>
11 #include <limits>
12 
13 namespace at::native {
14 
// Kernel name used by the jiterator variant of the complex branch in
// asinh_kernel_cuda; it matches the `asinh_impl` function in that variant's
// stringified source. Compiled out (`#if 0`) together with the jiterator
// path, which is disabled due to accuracy issues.
#if 0 && AT_USE_JITERATOR()
CONSTEXPR_EXCEPT_WIN_CUDA char asinh_name[] = "asinh_impl";
#endif
18 
// Elementwise inverse hyperbolic sine over a TensorIterator, launched via
// gpu_kernel.
//
// Dispatch is driven by iter.common_dtype():
//   * complex dtypes (ComplexFloat, ComplexDouble, ComplexHalf): each element
//     is upcast to at::opmath_type<scalar_t>, passed to ::asinh, and the
//     result is cast back to scalar_t;
//   * floating dtypes (float, double, Half, BFloat16): same scheme.
// Any other dtype reaches the AT_DISPATCH_* "not implemented" error, which
// reports the "asinh_cuda" label used below.
void asinh_kernel_cuda(TensorIteratorBase& iter) {
  const auto common_dtype = iter.common_dtype();
  if (at::isComplexType(common_dtype)) {
    // Disabled due to accuracy issues: the jiterator path is compiled out and
    // the precompiled gpu_kernel path in the #else branch is what gets built.
#if 0 && AT_USE_JITERATOR()
    static const auto asinh_string = jiterator_stringify(
        template <typename T> T asinh_impl(T a) { return std::asinh(a); });
    AT_DISPATCH_COMPLEX_TYPES_AND(
        kComplexHalf, common_dtype, "asinh_cuda", [&]() {
          jitted_gpu_kernel<
              /*name=*/asinh_name,
              /*return_dtype=*/scalar_t,
              /*common_dtype=*/scalar_t,
              /*arity=*/1>(iter, asinh_string);
        });
#else
    // Label matches the real-valued branch so dispatch errors/traces read
    // "asinh_cuda" rather than the jiterator variable name.
    AT_DISPATCH_COMPLEX_TYPES_AND(
        kComplexHalf, common_dtype, "asinh_cuda", [&]() {
          gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
            // Compute in opmath_t (for ComplexHalf presumably a wider complex
            // type — see ATen/OpMathType.h), then narrow back explicitly.
            using opmath_t = at::opmath_type<scalar_t>;
            return static_cast<scalar_t>(::asinh(static_cast<opmath_t>(a)));
          });
        });
#endif
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        ScalarType::Half,
        ScalarType::BFloat16,
        common_dtype,
        "asinh_cuda",
        [&]() {
          gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
            // Mirrors the complex branch: Half/BFloat16 are evaluated in their
            // opmath type (float) instead of relying on an implicit
            // conversion picking the ::asinh(float) overload.
            using opmath_t = at::opmath_type<scalar_t>;
            return static_cast<scalar_t>(::asinh(static_cast<opmath_t>(a)));
          });
        });
  }
}
56 
// Bind asinh_kernel_cuda as the implementation behind asinh_stub for this
// translation unit (presumably the CUDA dispatch slot, since this is a .cu file).
REGISTER_DISPATCH(asinh_stub, &asinh_kernel_cuda);
58 
59 } // namespace at::native
60