#define TORCH_ASSERT_NO_OPERATORS
#include <limits>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/cuda/JitLoops.cuh>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cuda/Math.cuh>
#include <ATen/native/Math.h>

namespace at::native {

#if AT_USE_JITERATOR()
CONSTEXPR_EXCEPT_WIN_CUDA char digamma_name[] = "digamma";
#endif // AT_USE_JITERATOR()
// See note [Jiterator]
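// When the Jiterator is enabled, the kernel body is JIT-compiled at runtime
// from the digamma_string CUDA source via jitted_gpu_kernel; otherwise a
// precompiled gpu_kernel lambda calling calc_digamma is used. Both paths
// dispatch over float, double, Half, and BFloat16.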
void digamma_kernel_cuda(TensorIteratorBase& iter) {
#if AT_USE_JITERATOR()
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.common_dtype(), "digamma_cuda", [&]() {
        jitted_gpu_kernel</*name=*/digamma_name,
                          /*return_dtype=*/ scalar_t,
                          /*common_dtype=*/ scalar_t,
                          /*arity=*/ 1>(iter, digamma_string);
      });
#else
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.common_dtype(), "digamma_cuda", [&]() {
        gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
          return calc_digamma(a);
        });
      });
#endif // AT_USE_JITERATOR()
}

// See note [Jiterator]
CONSTEXPR_EXCEPT_WIN_CUDA char trigamma_name[] = "trigamma";
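// Same structure as digamma_kernel_cuda: JIT-compiled from trigamma_string when
// the Jiterator is available, otherwise a precompiled calc_trigamma kernel.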
void trigamma_kernel_cuda(TensorIteratorBase& iter) {
#if AT_USE_JITERATOR()
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.common_dtype(), "trigamma_cuda", [&]() {
        jitted_gpu_kernel</*name=*/trigamma_name,
                          /*return_dtype=*/ scalar_t,
                          /*common_dtype=*/ scalar_t,
                          /*arity=*/ 1>(iter, trigamma_string);
      });
#else
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.common_dtype(), "trigamma_cuda", [&]() {
        gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
          return calc_trigamma(a);
        });
      });
#endif // AT_USE_JITERATOR()
}

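// polygamma(n, x): n == 0 and n == 1 are routed to the specialized digamma and
// trigamma kernels above; for any other n the order is forwarded to the jitted
// kernel through extra_args (or captured by the precompiled lambda).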
CONSTEXPR_EXCEPT_WIN_CUDA char polygamma_name[] = "polygamma";
void polygamma_kernel_cuda(TensorIteratorBase& iter, int64_t n) {
  if (n == 0) {
    digamma_kernel_cuda(iter);
  } else if (n == 1) {
    trigamma_kernel_cuda(iter);
  } else {
#if AT_USE_JITERATOR()
    // TODO : `unary_jitted_gpu_kernel` for cleaner UX.
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        iter.common_dtype(), "polygamma_cuda", [&]() {
          jitted_gpu_kernel<
              /*name=*/polygamma_name,
              /*return_dtype=*/scalar_t,
              /*common_dtype=*/scalar_t,
              /*arity=*/1>(
              iter,
              polygamma_string,
              /*scalar_pos=*/at::cuda::jit::BinaryFuncVariant::NoScalar,
              /*scalar_val=*/0,
              /*extra_args=*/std::make_tuple(n));
        });
#else
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        iter.common_dtype(), "polygamma_cuda", [&]() {
          gpu_kernel(iter, [=] GPU_LAMBDA(scalar_t a) -> scalar_t {
            return calc_polygamma<scalar_t, /*is_cuda=*/true>(a, static_cast<int>(n));
          });
        });
#endif // AT_USE_JITERATOR()
  }
}

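// lgamma: the Jiterator path JIT-compiles lgamma_string; the fallback calls the
// device ::lgamma overload directly.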
CONSTEXPR_EXCEPT_WIN_CUDA char lgamma_name[] = "lgamma_kernel";
void lgamma_kernel_cuda(TensorIteratorBase& iter) {
#if AT_USE_JITERATOR()
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.common_dtype(), "lgamma_cuda", [&]() {
        jitted_gpu_kernel</*name=*/lgamma_name,
                          /*return_dtype=*/ scalar_t,
                          /*common_dtype=*/ scalar_t,
                          /*arity=*/ 1>(iter, lgamma_string);
      });
#else
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.common_dtype(), "lgamma_cuda", [&]() {
        gpu_kernel(iter, []GPU_LAMBDA(scalar_t a) -> scalar_t {
          return ::lgamma(a);
        });
      });
#endif
}

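// Register the CUDA kernels with their device-generic dispatch stubs (declared
// in ATen/native/UnaryOps.h) so calls on CUDA tensors reach these implementations.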
REGISTER_DISPATCH(digamma_stub, &digamma_kernel_cuda);
REGISTER_DISPATCH(polygamma_stub, &polygamma_kernel_cuda);
REGISTER_DISPATCH(lgamma_stub, &lgamma_kernel_cuda);

} // namespace at::native