1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <limits>
3 #include <ATen/native/UnaryOps.h>
4 #include <ATen/native/cuda/Loops.cuh>
5 #include <ATen/AccumulateType.h>
6 #include <ATen/Dispatch.h>
7 #include <ATen/native/cuda/jit_utils.h>
8 #include <ATen/native/cuda/JitLoops.cuh>
9 #include <ATen/native/DispatchStub.h>
10 #include <ATen/native/TensorIterator.h>
11 #include <ATen/native/cuda/Math.cuh>
12
13 namespace at::native {
14
#if AT_USE_JITERATOR()
// Kernel name handed to jitted_gpu_kernel in log_kernel_cuda; it names the
// `log_kernel` function defined in that block's jiterator_stringify source.
// Only declared when the jiterator is enabled, since only that path uses it.
CONSTEXPR_EXCEPT_WIN_CUDA char log_name[] = "log_kernel";
#endif
18
// Elementwise natural logarithm for `iter` on CUDA (registered as log_stub).
//
// Dtype handling (driven by iter.common_dtype()):
//  * real: float, double, Half, BFloat16 -> precompiled lambda calling ::log.
//  * complex, including ComplexHalf:
//      - jiterator enabled: the `log_kernel` source below is JIT-compiled
//        and launched under the name `log_name`;
//      - otherwise: precompiled lambda that upcasts each element to
//        at::opmath_type<scalar_t> before calling ::log.
// Dtypes outside these lists are rejected by the AT_DISPATCH macros.
void log_kernel_cuda(TensorIteratorBase& iter) {
  const auto dtype = iter.common_dtype();

  // Real dtypes never touch the jiterator; handle them and leave early.
  if (!at::isComplexType(dtype)) {
    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, dtype, "log_cuda", [&]() {
      gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t { return ::log(x); });
    });
    return;
  }

#if AT_USE_JITERATOR()
  // Device source compiled at runtime; the function name inside the string
  // is the one referenced through `log_name`.
  static const auto jit_source = jiterator_stringify(
      template <typename T> T log_kernel(T x) { return std::log(x); });
  AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "log_cuda", [&]() {
    jitted_gpu_kernel<
        /*name=*/log_name,
        /*return_dtype=*/scalar_t,
        /*common_dtype=*/scalar_t,
        /*arity=*/1>(iter, jit_source);
  });
#else
  AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "log_cuda", [&]() {
    gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
      // Compute in opmath precision (relevant for ComplexHalf inputs).
      using opmath_t = at::opmath_type<scalar_t>;
      return ::log(static_cast<opmath_t>(x));
    });
  });
#endif
}
49
#if AT_USE_JITERATOR()
// Kernel name handed to jitted_gpu_kernel in log10_kernel_cuda; it names the
// `log10_kernel` function defined in that block's jiterator_stringify source.
// Guarded like log_name: it is only referenced on the jiterator path, so
// declaring it unconditionally leaves an unused constant when the jiterator
// is disabled.
CONSTEXPR_EXCEPT_WIN_CUDA char log10_name[] = "log10_kernel";
#endif
// Elementwise base-10 logarithm for `iter` on CUDA (registered as log10_stub).
//
// Dtype handling (driven by iter.common_dtype()):
//  * real: float, double, Half, BFloat16 -> precompiled lambda calling ::log10.
//  * complex (AT_DISPATCH_COMPLEX_TYPES, i.e. no ComplexHalf, unlike log):
//      - jiterator enabled: the `log10_kernel` source below is JIT-compiled
//        and launched under the name `log10_name`;
//      - otherwise: precompiled lambda calling ::log10 directly.
// Dtypes outside these lists are rejected by the AT_DISPATCH macros.
void log10_kernel_cuda(TensorIteratorBase& iter) {
  const auto dtype = iter.common_dtype();

  // Real dtypes never touch the jiterator; handle them and leave early.
  if (!at::isComplexType(dtype)) {
    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, dtype, "log10_cuda", [&]() {
      gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t { return ::log10(x); });
    });
    return;
  }

#if AT_USE_JITERATOR()
  // Device source compiled at runtime; the function name inside the string
  // is the one referenced through `log10_name`.
  static const auto jit_source = jiterator_stringify(
      template <typename T> T log10_kernel(T x) { return std::log10(x); });
  AT_DISPATCH_COMPLEX_TYPES(dtype, "log10_cuda", [&]() {
    jitted_gpu_kernel<
        /*name=*/log10_name,
        /*return_dtype=*/scalar_t,
        /*common_dtype=*/scalar_t,
        /*arity=*/1>(iter, jit_source);
  });
#else
  AT_DISPATCH_COMPLEX_TYPES(dtype, "log10_cuda", [&]() {
    gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t { return ::log10(x); });
  });
#endif
}
78
// Elementwise log1p(x) = log(1 + x) for `iter` on CUDA (registered as
// log1p_stub).
//
// Unlike log/log10/log2, there is no jiterator path: one precompiled lambda
// calling ::log1p serves every dispatched dtype (float, double, Half,
// BFloat16 and the complex types covered by
// AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2). Other dtypes are rejected by
// the dispatch macro.
void log1p_kernel_cuda(TensorIteratorBase& iter) {
  const auto dtype = iter.common_dtype();
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16, dtype, "log1p_cuda", [&]() {
        gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
          return ::log1p(x);
        });
      });
}
86
#if AT_USE_JITERATOR()
// Kernel name handed to jitted_gpu_kernel in log2_kernel_cuda; it names the
// `log2_kernel` function defined in that block's jiterator_stringify source.
// Guarded like log_name: it is only referenced on the jiterator path, so
// declaring it unconditionally leaves an unused constant when the jiterator
// is disabled.
CONSTEXPR_EXCEPT_WIN_CUDA char log2_name[] = "log2_kernel";
#endif
// Elementwise base-2 logarithm for `iter` on CUDA (registered as log2_stub).
//
// Dtype handling (driven by iter.common_dtype()):
//  * real: float, double, Half, BFloat16 -> precompiled lambda calling ::log2.
//  * complex (AT_DISPATCH_COMPLEX_TYPES, i.e. no ComplexHalf, unlike log):
//      - jiterator enabled: the `log2_kernel` source below is JIT-compiled
//        and launched under the name `log2_name`;
//      - otherwise: precompiled lambda calling ::log2 directly.
// Dtypes outside these lists are rejected by the AT_DISPATCH macros.
void log2_kernel_cuda(TensorIteratorBase& iter) {
  const auto dtype = iter.common_dtype();

  // Real dtypes never touch the jiterator; handle them and leave early.
  if (!at::isComplexType(dtype)) {
    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, dtype, "log2_cuda", [&]() {
      gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t { return ::log2(x); });
    });
    return;
  }

#if AT_USE_JITERATOR()
  // Device source compiled at runtime; the function name inside the string
  // is the one referenced through `log2_name`.
  static const auto jit_source = jiterator_stringify(
      template <typename T> T log2_kernel(T x) { return std::log2(x); });
  AT_DISPATCH_COMPLEX_TYPES(dtype, "log2_cuda", [&]() {
    jitted_gpu_kernel<
        /*name=*/log2_name,
        /*return_dtype=*/scalar_t,
        /*common_dtype=*/scalar_t,
        /*arity=*/1>(iter, jit_source);
  });
#else
  AT_DISPATCH_COMPLEX_TYPES(dtype, "log2_cuda", [&]() {
    gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t { return ::log2(x); });
  });
#endif
}
115
// Bind each logarithm dispatch stub to its CUDA implementation defined above.
REGISTER_DISPATCH(log_stub, &log_kernel_cuda);
REGISTER_DISPATCH(log10_stub, &log10_kernel_cuda);
REGISTER_DISPATCH(log2_stub, &log2_kernel_cuda);
REGISTER_DISPATCH(log1p_stub, &log1p_kernel_cuda);
120
121 } // namespace at::native
122