1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/AccumulateType.h>
3 #include <ATen/Dispatch.h>
4 #include <ATen/OpMathType.h>
5 #include <ATen/native/DispatchStub.h>
6 #include <ATen/native/TensorIterator.h>
7 #include <ATen/native/UnaryOps.h>
8 #include <ATen/native/cuda/JitLoops.cuh>
9 #include <ATen/native/cuda/Loops.cuh>
10 #include <ATen/native/cuda/Math.cuh>
11 #include <limits>
12
namespace at::native {

// Kernel name handed to the jiterator. Its only use is the disabled jiterator
// path inside asin_kernel_cuda, so it is compiled out under the same
// `0 &&` guard as that path.
#if 0 && AT_USE_JITERATOR()
CONSTEXPR_EXCEPT_WIN_CUDA char asin_name[] = "asin_impl";
#endif

// Elementwise arcsine over the operands described by `iter`.
//
// Dispatches on iter.common_dtype():
//   * complex dtypes (plus kComplexHalf): each element is cast to
//     at::opmath_type<scalar_t> before calling ::asin, and the result is
//     returned as scalar_t. (Presumably opmath widens reduced-precision
//     complex types for the computation -- see ATen/OpMathType.h.)
//   * floating dtypes (plus kHalf, kBFloat16): ::asin applied to the element.
//
// Both branches pass the same op label, "asin_cuda", to the AT_DISPATCH_*
// macro so any dtype diagnostics the macros emit name the op consistently.
void asin_kernel_cuda(TensorIteratorBase& iter) {
  auto common_dtype = iter.common_dtype();
  if (at::isComplexType(common_dtype)) {
    // Disabled due to accuracy issues
#if 0 && AT_USE_JITERATOR()
    static const auto asin_string = jiterator_stringify(
        template <typename T> T asin_impl(T a) { return std::asin(a); });
    AT_DISPATCH_COMPLEX_TYPES_AND(
        kComplexHalf, common_dtype, "asin_cuda", [&]() {
          jitted_gpu_kernel<
              /*name=*/asin_name,
              /*return_dtype=*/scalar_t,
              /*common_dtype=*/scalar_t,
              /*arity=*/1>(iter, asin_string);
        });
#else
    AT_DISPATCH_COMPLEX_TYPES_AND(
        kComplexHalf, common_dtype, "asin_cuda", [&]() {
          gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
            // Compute in the op-math type rather than directly in scalar_t.
            using opmath_t = at::opmath_type<scalar_t>;
            return ::asin(static_cast<opmath_t>(a));
          });
        });
#endif
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(
        kHalf, kBFloat16, common_dtype, "asin_cuda", [&]() {
          gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
            return ::asin(a);
          });
        });
  }
}

// Binds asin_kernel_cuda to asin_stub (presumably declared via
// ATen/native/UnaryOps.h) for this backend.
REGISTER_DISPATCH(asin_stub, &asin_kernel_cuda);

} // namespace at::native
56