#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/OpMathType.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/Math.cuh>
#include <limits>

namespace at::native {

#if 0 && AT_USE_JITERATOR()
CONSTEXPR_EXCEPT_WIN_CUDA char acos_name[] = "acos_impl";
#endif
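
// CUDA kernel for at::acos: dispatches on the iterator's common dtype and
// launches an elementwise arccosine over complex or real floating-point types.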
void acos_kernel_cuda(TensorIteratorBase& iter) {
  auto common_dtype = iter.common_dtype();
  if (at::isComplexType(common_dtype)) {
    // Jiterator path disabled due to accuracy issues; use the gpu_kernel
    // fallback below instead.
#if 0 && AT_USE_JITERATOR()
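    // The jiterator path JIT-compiles the stringified acos_impl at runtime
    // instead of instantiating the complex math templates at build time.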
    static const auto acos_string = jiterator_stringify(
        template <typename T> T acos_impl(T a) { return std::acos(a); });
    AT_DISPATCH_COMPLEX_TYPES_AND(
        kComplexHalf, common_dtype, "acos_name", [&]() {
          jitted_gpu_kernel<
              /*name=*/acos_name,
              /*return_dtype=*/scalar_t,
              /*common_dtype=*/scalar_t,
              /*arity=*/1>(iter, acos_string);
        });
#else
    AT_DISPATCH_COMPLEX_TYPES_AND(
        kComplexHalf, common_dtype, "acos_name", [&]() {
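          // Compute in opmath_t so ComplexHalf inputs are promoted to
          // c10::complex<float> before calling acos.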
          gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
            using opmath_t = at::opmath_type<scalar_t>;
            return ::acos(static_cast<opmath_t>(a));
          });
        });
#endif
  } else {
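    // Real floating-point path: Half and BFloat16 convert implicitly to
    // float, so ::acos runs at float precision for them.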
    AT_DISPATCH_FLOATING_TYPES_AND2(
        ScalarType::Half,
        ScalarType::BFloat16,
        common_dtype,
        "acos_cuda",
        [&]() {
          gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
            return ::acos(a);
          });
        });
  }
}

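// Register this kernel as the CUDA implementation behind acos_stub.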
REGISTER_DISPATCH(acos_stub, &acos_kernel_cuda);

} // namespace at::native