1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/native/UnaryOps.h>
3 #include <ATen/native/cuda/Loops.cuh>
4 #include <ATen/native/cuda/JitLoops.cuh>
5 #include <ATen/Dispatch.h>
6 #include <ATen/native/DispatchStub.h>
7 #include <ATen/native/TensorIterator.h>
8
9 namespace at::native {
10
// Elementwise absolute-value functor applied per element by gpu_kernel.
// Device-only: operator() is __device__, so this is never called on host.
// std::abs here resolves to the overload for scalar_t (including ATen's
// complex types in the non-jiterator path below, where the magnitude is
// computed and then converted back through scalar_t by the loop machinery).
template<typename scalar_t>
struct AbsFunctor {
  __device__ __forceinline__ scalar_t operator() (const scalar_t a) const {
    return std::abs(a);
  }
};
17
// Kernel name handed to the jiterator so the NVRTC-compiled complex-abs
// kernel has a stable identifier.
CONSTEXPR_EXCEPT_WIN_CUDA char abs_name[] = "abs_kernel";

// CUDA implementation of the abs_stub dispatch entry.
//
// Complex dtypes (including ComplexHalf) are routed through the jiterator
// when AT_USE_JITERATOR() is enabled, compiling the kernel at runtime;
// otherwise they fall back to a precompiled gpu_kernel that computes in
// opmath_t (the wider math type, e.g. complex<float> for complex<Half>).
// All remaining dtypes — all standard types plus Half, BFloat16, and Bool —
// use the precompiled AbsFunctor path directly.
void abs_kernel_cuda(TensorIteratorBase& iter) {
  auto dtype = iter.dtype();
  if (at::isComplexType(dtype)) {
#if AT_USE_JITERATOR()
    // Kernel source string is built once; jiterator caches compilation.
    static const auto abs_string = jiterator_stringify(
        template <typename T> T abs_kernel(T x) { return std::abs(x); });
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "abs_cuda", [&]() {
      jitted_gpu_kernel<
          /*name=*/abs_name,
          /*return_dtype=*/scalar_t,
          /*common_dtype=*/scalar_t,
          /*arity=*/1>(iter, abs_string);
    });
#else
    AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "abs_cuda", [&]() {
      // Compute in the opmath type for accuracy on reduced-precision complex.
      using opmath_t = at::opmath_type<scalar_t>;
      gpu_kernel(iter, AbsFunctor<opmath_t>());
    });
#endif
  } else {
    AT_DISPATCH_ALL_TYPES_AND3(
        ScalarType::Half,
        ScalarType::BFloat16,
        ScalarType::Bool,
        iter.dtype(),
        "abs_cuda",
        [&]() { gpu_kernel(iter, AbsFunctor<scalar_t>()); });
  }
}
48
// Register the CUDA kernel as the implementation behind the device-generic
// abs_stub declared in ATen/native/UnaryOps.h.
REGISTER_DISPATCH(abs_stub, &abs_kernel_cuda);
50
51 } // namespace at::native
52