#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/OpMathType.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Math.cuh>
#include <limits>

namespace at::native {

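// Kernel name used by the jiterator path below; it sits behind the same
// disabled (#if 0) guard as that path.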
#if 0 && AT_USE_JITERATOR()
CONSTEXPR_EXCEPT_WIN_CUDA char acosh_name[] = "acosh_impl";
#endif

void acosh_kernel_cuda(TensorIteratorBase& iter) {
  auto common_dtype = iter.common_dtype();
  if (at::isComplexType(common_dtype)) {
    // Disabled due to accuracy issues
#if 0 && AT_USE_JITERATOR()
    static const auto acosh_string = jiterator_stringify(
        template <typename T>
        T acosh_impl(T a) {
          return std::acosh(a);
        }
    );
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "acosh_name", [&]() {
      jitted_gpu_kernel<
          /*name=*/ acosh_name,
          /*return_dtype=*/ scalar_t,
          /*common_dtype=*/ scalar_t,
          /*arity=*/ 1>(iter, acosh_string);
    });
#else
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, common_dtype, "acosh_name", [&]() {
      gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
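        // Upcast to the op math type (e.g. complex<Half> computes as
        // complex<float>) before evaluating acosh; the result narrows back
        // to scalar_t on return.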
        using opmath_t = at::opmath_type<scalar_t>;
        return ::acosh(static_cast<opmath_t>(a));
      });
    });
#endif
  } else {
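    // Real floating-point types: Half and BFloat16 promote implicitly to
    // float for the ::acosh call, and the result converts back to scalar_t.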
    AT_DISPATCH_FLOATING_TYPES_AND2(
        ScalarType::Half, ScalarType::BFloat16,
        common_dtype, "acosh_cuda",
        [&]() {
          gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
            return ::acosh(a);
          });
        });
  }
}

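// Register the CUDA kernel with the acosh dispatch stub.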
REGISTER_DISPATCH(acosh_stub, &acosh_kernel_cuda);

} // namespace at::native