// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/AbsKernel.cu
// (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
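//
// CUDA implementation of the elementwise absolute-value kernel. Complex dtypes
// are JIT-compiled through the jiterator when AT_USE_JITERATOR() is enabled
// (with a precompiled gpu_kernel fallback); all other dtypes go straight to
// gpu_kernel with AbsFunctor.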
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/native/UnaryOps.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>

namespace at::native {
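// Elementwise |x| functor: std::abs covers integral, floating-point, and
// complex element types (for complex it returns the magnitude).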
template<typename scalar_t>
struct AbsFunctor {
  __device__ __forceinline__ scalar_t operator() (const scalar_t a) const {
    return std::abs(a);
  }
};

CONSTEXPR_EXCEPT_WIN_CUDA char abs_name[] = "abs_kernel";
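// abs_kernel_cuda: writes |x| into the iterator's output. abs_name above is
// the kernel name handed to jitted_gpu_kernel on the jiterator path;
// non-complex dtypes dispatch directly to gpu_kernel with AbsFunctor.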
void abs_kernel_cuda(TensorIteratorBase& iter) {
  auto dtype = iter.dtype();
  if (at::isComplexType(dtype)) {
#if AT_USE_JITERATOR()
    // Complex dtypes: JIT-compile the kernel from a stringified source.
    static const auto abs_string = jiterator_stringify(
        template <typename T> T abs_kernel(T x) { return std::abs(x); });
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "abs_cuda", [&]() {
      jitted_gpu_kernel<
          /*name=*/abs_name,
          /*return_dtype=*/scalar_t,
          /*common_dtype=*/scalar_t,
          /*arity=*/1>(iter, abs_string);
    });
#else
    // Jiterator disabled: fall back to a precompiled kernel, computing in the
    // opmath type (e.g. complex<float> for ComplexHalf).
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "abs_cuda", [&]() {
      using opmath_t = at::opmath_type<scalar_t>;
      gpu_kernel(iter, AbsFunctor<opmath_t>());
    });
#endif
  } else {
    // Real, integral, and bool dtypes: precompiled gpu_kernel with AbsFunctor.
    AT_DISPATCH_ALL_TYPES_AND3(
        ScalarType::Half,
        ScalarType::BFloat16,
        ScalarType::Bool,
        iter.dtype(),
        "abs_cuda",
        [&]() { gpu_kernel(iter, AbsFunctor<scalar_t>()); });
  }
}
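// REGISTER_DISPATCH below fills in the CUDA slot of abs_stub (declared in
// ATen/native/UnaryOps.h), so at::abs on a CUDA tensor ends up here. A minimal
// usage sketch, assuming an ATen/libtorch build (not part of this file):
//
//   at::Tensor t = at::randn({4}, at::device(at::kCUDA)) - 0.5;
//   at::Tensor r = at::abs(t);  // routed to abs_kernel_cuda via abs_stub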
REGISTER_DISPATCH(abs_stub, &abs_kernel_cuda);

} // namespace at::native