// aten/src/ATen/native/cuda/UnaryGammaKernels.cu
#define TORCH_ASSERT_NO_OPERATORS
#include <limits>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/native/Math.h>

namespace at::native {

#if AT_USE_JITERATOR()
CONSTEXPR_EXCEPT_WIN_CUDA char digamma_name[] = "digamma";
#endif // AT_USE_JITERATOR()
// See note [Jiterator]
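// digamma(x) computes psi(x) = d/dx ln(Gamma(x)), the logarithmic derivative of the gamma function.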
void digamma_kernel_cuda(TensorIteratorBase& iter) {
  #if AT_USE_JITERATOR()
    AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.common_dtype(), "digamma_cuda", [&]() {
        jitted_gpu_kernel</*name=*/digamma_name,
                          /*return_dtype=*/ scalar_t,
                          /*common_dtype=*/ scalar_t,
                          /*arity=*/ 1>(iter, digamma_string);
    });
  #else
    AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.common_dtype(), "digamma_cuda", [&]() {
        gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
          return calc_digamma(a);
        });
    });
  #endif // AT_USE_JITERATOR()
}

// See note [Jiterator]
CONSTEXPR_EXCEPT_WIN_CUDA char trigamma_name[] = "trigamma";
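// trigamma(x) computes psi'(x), the derivative of digamma (i.e. the second derivative of ln(Gamma(x))).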
void trigamma_kernel_cuda(TensorIteratorBase& iter) {
  #if AT_USE_JITERATOR()
    AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.common_dtype(), "trigamma_cuda", [&]() {
        jitted_gpu_kernel</*name=*/trigamma_name,
                          /*return_dtype=*/ scalar_t,
                          /*common_dtype=*/ scalar_t,
                          /*arity=*/ 1>(iter, trigamma_string);
    });
  #else
    AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.common_dtype(), "trigamma_cuda", [&]() {
        gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
          return calc_trigamma(a);
        });
    });
  #endif // AT_USE_JITERATOR()
}

CONSTEXPR_EXCEPT_WIN_CUDA char polygamma_name[] = "polygamma";
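// polygamma(n, x) computes the n-th derivative of digamma; n == 0 and n == 1 are
// routed to the specialized digamma / trigamma kernels above.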
void polygamma_kernel_cuda(TensorIteratorBase& iter, int64_t n) {
  if (n == 0) {
    digamma_kernel_cuda(iter);
  } else if (n == 1) {
    trigamma_kernel_cuda(iter);
  } else {
#if AT_USE_JITERATOR()
    // TODO: `unary_jitted_gpu_kernel` for cleaner UX.
    AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
        iter.common_dtype(), "polygamma_cuda", [&]() {
          jitted_gpu_kernel<
              /*name=*/polygamma_name,
              /*return_dtype=*/scalar_t,
              /*common_dtype=*/scalar_t,
              /*arity=*/1>(
              iter,
              polygamma_string,
              /*scalar_pos=*/at::cuda::jit::BinaryFuncVariant::NoScalar,
              /*scalar_val=*/0,
              /*extra_args=*/std::make_tuple(n));
        });
#else
    AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
        iter.common_dtype(), "polygamma_cuda", [&]() {
          gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t a) -> scalar_t {
            return calc_polygamma<scalar_t, /*is_cuda=*/true>(a, static_cast<int>(n));
          });
        });
#endif // AT_USE_JITERATOR()
  }
}

CONSTEXPR_EXCEPT_WIN_CUDA char lgamma_name[] = "lgamma_kernel";
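// lgamma(x) computes ln(|Gamma(x)|), the natural log of the absolute value of the gamma function.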
void lgamma_kernel_cuda(TensorIteratorBase& iter) {
  #if AT_USE_JITERATOR()
    AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.common_dtype(), "lgamma_cuda", [&]() {
        jitted_gpu_kernel</*name=*/lgamma_name,
                          /*return_dtype=*/ scalar_t,
                          /*common_dtype=*/ scalar_t,
                          /*arity=*/ 1>(iter, lgamma_string);
    });
  #else
    AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.common_dtype(), "lgamma_cuda", [&]() {
        gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
          return ::lgamma(a);
        });
    });
  #endif // AT_USE_JITERATOR()
}

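// Bind the CUDA implementations to the dispatch stubs declared in ATen/native/UnaryOps.h.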
REGISTER_DISPATCH(digamma_stub, &digamma_kernel_cuda);
REGISTER_DISPATCH(polygamma_stub, &polygamma_kernel_cuda);
REGISTER_DISPATCH(lgamma_stub, &lgamma_kernel_cuda);

} // namespace at::native