xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/UnaryLogKernels.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <limits>
3 #include <ATen/native/UnaryOps.h>
4 #include <ATen/native/cuda/Loops.cuh>
5 #include <ATen/AccumulateType.h>
6 #include <ATen/Dispatch.h>
7 #include <ATen/native/cuda/jit_utils.h>
8 #include <ATen/native/cuda/JitLoops.cuh>
9 #include <ATen/native/DispatchStub.h>
10 #include <ATen/native/TensorIterator.h>
11 #include <ATen/native/cuda/Math.cuh>
12 
13 namespace at::native {
14 
#if AT_USE_JITERATOR()
CONSTEXPR_EXCEPT_WIN_CUDA char log_name[] = "log_kernel";
#endif

// Element-wise natural logarithm for the CUDA backend (hooked up to `log_stub`
// via REGISTER_DISPATCH at the bottom of this file).
//
// - Real dtypes (float, double, Half, BFloat16): a precompiled gpu_kernel that
//   applies ::log to each element.
// - Complex dtypes (ComplexFloat, ComplexDouble, ComplexHalf): a
//   runtime-compiled jiterator kernel when AT_USE_JITERATOR() is enabled;
//   otherwise a precompiled gpu_kernel that evaluates ::log in
//   at::opmath_type<scalar_t> and narrows the result back to scalar_t.
void log_kernel_cuda(TensorIteratorBase& iter) {
  const auto dtype = iter.common_dtype();

  if (!at::isComplexType(dtype)) {
    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, dtype, "log_cuda", [&]() {
      gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t { return ::log(x); });
    });
    return;
  }

#if AT_USE_JITERATOR()
  // Kernel source handed to the jiterator; the function-local static builds
  // the string only once per process.
  static const auto jit_source = jiterator_stringify(
      template <typename T> T log_kernel(T x) { return std::log(x); });
  AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "log_cuda", [&]() {
    jitted_gpu_kernel<
        /*name=*/log_name,
        /*return_dtype=*/scalar_t,
        /*common_dtype=*/scalar_t,
        /*arity=*/1>(iter, jit_source);
  });
#else
  AT_DISPATCH_COMPLEX_TYPES_AND(kComplexHalf, dtype, "log_cuda", [&]() {
    gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
      // Compute in the op-math type (presumably a wider type such as
      // complex<float> for ComplexHalf), then convert back on return.
      using opmath_t = at::opmath_type<scalar_t>;
      const opmath_t wide = static_cast<opmath_t>(x);
      return ::log(wide);
    });
  });
#endif
}
49 
// Jiterator kernel name. It is only referenced from the AT_USE_JITERATOR()
// branch below, so guard it the same way as `log_name`. This avoids an unused
// constant when the jiterator is disabled.
#if AT_USE_JITERATOR()
CONSTEXPR_EXCEPT_WIN_CUDA char log10_name[] = "log10_kernel";
#endif

// Element-wise base-10 logarithm for the CUDA backend (hooked up to
// `log10_stub` via REGISTER_DISPATCH at the bottom of this file).
//
// - Real dtypes (float, double, Half, BFloat16): a precompiled gpu_kernel
//   applying ::log10.
// - Complex dtypes: only ComplexFloat/ComplexDouble are dispatched here;
//   ComplexHalf is not listed, unlike log_kernel_cuda. They use a jiterator
//   kernel when AT_USE_JITERATOR() is enabled, otherwise a precompiled
//   gpu_kernel applying ::log10 directly in scalar_t.
void log10_kernel_cuda(TensorIteratorBase& iter) {
  // Query the common dtype once and reuse it for every dispatch below.
  const auto common_dtype = iter.common_dtype();
  if (at::isComplexType(common_dtype)) {
#if AT_USE_JITERATOR()
    // Kernel source for the jiterator; built once thanks to the static.
    static const auto log10_string = jiterator_stringify(
        template <typename T> T log10_kernel(T x) { return std::log10(x); });
    AT_DISPATCH_COMPLEX_TYPES(common_dtype, "log10_cuda", [&]() {
      jitted_gpu_kernel<
          /*name=*/log10_name,
          /*return_dtype=*/scalar_t,
          /*common_dtype=*/scalar_t,
          /*arity=*/1>(iter, log10_string);
    });
#else
    AT_DISPATCH_COMPLEX_TYPES(common_dtype, "log10_cuda", [&]() {
      gpu_kernel(
          iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { return ::log10(a); });
    });
#endif
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, common_dtype, "log10_cuda", [&]() {
      gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
        return ::log10(a);
      });
    });
  }
}
78 
// Element-wise log(1 + x) for the CUDA backend (hooked up to `log1p_stub` via
// REGISTER_DISPATCH at the bottom of this file).
//
// A single precompiled gpu_kernel covers float, double, Half, BFloat16 and
// the complex float/double types. Each element is mapped through ::log1p;
// unlike log/log10/log2 in this file, there is no jiterator path.
void log1p_kernel_cuda(TensorIteratorBase& iter) {
  const auto dtype = iter.common_dtype();
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      ScalarType::Half, ScalarType::BFloat16, dtype, "log1p_cuda", [&]() {
        gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t { return ::log1p(x); });
      });
}
86 
// Jiterator kernel name. It is only referenced from the AT_USE_JITERATOR()
// branch below, so guard it the same way as `log_name`. This avoids an unused
// constant when the jiterator is disabled.
#if AT_USE_JITERATOR()
CONSTEXPR_EXCEPT_WIN_CUDA char log2_name[] = "log2_kernel";
#endif

// Element-wise base-2 logarithm for the CUDA backend (hooked up to
// `log2_stub` via REGISTER_DISPATCH at the bottom of this file).
//
// - Real dtypes (float, double, Half, BFloat16): a precompiled gpu_kernel
//   applying ::log2.
// - Complex dtypes: only ComplexFloat/ComplexDouble are dispatched here;
//   ComplexHalf is not listed, unlike log_kernel_cuda. They use a jiterator
//   kernel when AT_USE_JITERATOR() is enabled, otherwise a precompiled
//   gpu_kernel applying ::log2 directly in scalar_t.
void log2_kernel_cuda(TensorIteratorBase& iter) {
  // Query the common dtype once and reuse it for every dispatch below.
  const auto common_dtype = iter.common_dtype();
  if (at::isComplexType(common_dtype)) {
#if AT_USE_JITERATOR()
    // Kernel source for the jiterator; built once thanks to the static.
    static const auto log2_string = jiterator_stringify(
        template <typename T> T log2_kernel(T x) { return std::log2(x); });
    AT_DISPATCH_COMPLEX_TYPES(common_dtype, "log2_cuda", [&]() {
      jitted_gpu_kernel<
          /*name=*/log2_name,
          /*return_dtype=*/scalar_t,
          /*common_dtype=*/scalar_t,
          /*arity=*/1>(iter, log2_string);
    });
#else
    AT_DISPATCH_COMPLEX_TYPES(common_dtype, "log2_cuda", [&]() {
      gpu_kernel(
          iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t { return ::log2(a); });
    });
#endif
  } else {
    AT_DISPATCH_FLOATING_TYPES_AND2(ScalarType::Half, ScalarType::BFloat16, common_dtype, "log2_cuda", [&]() {
      gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a) -> scalar_t {
        return ::log2(a);
      });
    });
  }
}
115 
// Register each CUDA kernel defined in this file with its dispatch stub (the
// stubs are declared elsewhere, presumably ATen/native/UnaryOps.h). Each line
// targets a different stub, so the order here is not significant; it follows
// the order in which the kernels are defined above.
REGISTER_DISPATCH(log_stub, &log_kernel_cuda);
REGISTER_DISPATCH(log10_stub, &log10_kernel_cuda);
REGISTER_DISPATCH(log1p_stub, &log1p_kernel_cuda);
REGISTER_DISPATCH(log2_stub, &log2_kernel_cuda);
120 
121 } // namespace at::native
122