#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/core/TensorBase.h>
#include <ATen/OpMathType.h>
#include <ATen/Dispatch.h>

#include <ATen/native/cuda/ScanKernels.h>
#include <ATen/native/cuda/ScanUtils.cuh>

#include <cmath>
#include <limits>

namespace at::native {

// custom min and max to be used in logcumsumexp for complex arguments;
// complex numbers are ordered by their real part, since |exp(z)| = exp(real(z))
// determines which term dominates the sum, and NaNs are propagated
template <typename scalar_t, bool min>
__host__ __device__ c10::complex<scalar_t> _logcumsumexp_minmax(const c10::complex<scalar_t>& x, const c10::complex<scalar_t>& y) {
  scalar_t xr = std::real(x);
  scalar_t yr = std::real(y);
  if (::isnan(yr) || (::isnan(std::imag(y)))) {
    return y;
  } else if (::isnan(xr) || (::isnan(std::imag(x)))) {
    return x;
  } else if (min) { // min
    return (xr < yr) ? x : y;
  } else { // max
    return (xr >= yr) ? x : y;
  }
}

template <typename scalar_t>
__host__ __device__ scalar_t _log_add_exp_helper(const scalar_t& x, const scalar_t& y) {
  // Reference: https://www.tensorflow.org/api_docs/python/tf/math/cumulative_logsumexp
  // Using the original expression: `at::_isnan(y) ? y : std::min(x, y)` causes an error in ROCM
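  // A sketch of the numerically stable identity used below (the standard
  // log-sum-exp rewriting, not specific to this file):
  //   log(exp(x) + exp(y)) = max + log(1 + exp(min - max))
  //                        = max + log1p(exp(min - max))
  // Since min - max <= 0, exp(min - max) never overflows, and log1p keeps
  // precision when exp(min - max) is tiny.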
  auto isnan_x = at::_isnan(x);
  auto isnan_y = at::_isnan(y);
  scalar_t min = isnan_y ? y : (isnan_x ? x : std::min(x, y));
  scalar_t max = isnan_y ? y : (isnan_x ? x : std::max(x, y));
  if (min != max || ::isfinite(min)) {
    // nan will be propagated here
    return ::log1p(std::exp(min - max)) + max;
  } else {
    // special case to correctly handle infinite cases
    return x;
  }
}
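
// Quick sanity example (illustrative only): with x = log(2) and y = log(3),
// _log_add_exp_helper(x, y) equals log(5) up to floating-point rounding,
// since exp(x) + exp(y) == 2 + 3.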

template <typename scalar_t>
__host__ __device__ c10::complex<scalar_t> _fast_build_exp(const c10::complex<scalar_t>& x) {
  // complex exponential function, but implemented manually to get fast compilation time
  // this function only handles the case where x is finite (not inf nor nan)
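  // This is Euler's formula written out explicitly:
  //   exp(a + b*i) = exp(a) * (cos(b) + sin(b) * i)
  // so the real part is exp(a)*cos(b) and the imaginary part is exp(a)*sin(b).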
  auto xreal = std::real(x);
  auto ximag = std::imag(x);
  auto exp_x_abs = std::exp(xreal);
  auto exp_x_real = exp_x_abs * std::cos(ximag);
  auto exp_x_imag = exp_x_abs * std::sin(ximag);
  return {exp_x_real, exp_x_imag};
}

template <typename scalar_t>
__host__ __device__ c10::complex<scalar_t> _fast_build_exp_inf(const c10::complex<scalar_t>& x) {
  // complex exponential function, but implemented manually to get fast compilation time
  // this function only handles the case where the real part of x is infinite
  auto ximag = std::imag(x);
  auto exp_x_abs = std::numeric_limits<scalar_t>::infinity();
  auto sin = std::sin(ximag);
  auto cos = std::cos(ximag);
  // special case if the angle is exactly a multiple of pi/2: multiplying
  // inf by an exact 0 would produce nan, so return 0 for that component instead
  auto exp_x_real = (cos == 0) ? (scalar_t)0.0 : exp_x_abs * cos;
  auto exp_x_imag = (sin == 0) ? (scalar_t)0.0 : exp_x_abs * sin;
  return {exp_x_real, exp_x_imag};
}

template <typename scalar_t>
__host__ __device__ c10::complex<scalar_t> _log_add_exp_helper(const c10::complex<scalar_t>& x, const c10::complex<scalar_t>& y) {
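  // Rough outline of the cases handled below (the same log-sum-exp rewriting
  // as the real-valued helper, specialized for complex inputs):
  //   1. either input carries a NaN component -> propagate NaN,
  //   2. both real parts are -inf, or both are +inf -> handled explicitly,
  //      since min - max would be nan there,
  //   3. otherwise -> max + log1p(exp(min - max)).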
  c10::complex<scalar_t> min = _logcumsumexp_minmax<scalar_t, /*min=*/true>(x, y);
  c10::complex<scalar_t> max = _logcumsumexp_minmax<scalar_t, /*min=*/false>(x, y);
  scalar_t min_real = std::real(min);
  scalar_t max_real = std::real(max);

  if (::isnan(min_real) || ::isnan(std::imag(min))) {
    // handling the "infectious" NaNs
    return {std::numeric_limits<scalar_t>::quiet_NaN(), std::numeric_limits<scalar_t>::quiet_NaN()};
  } else if ((!::isfinite(min_real)) && (min_real == max_real)) {
    if (min_real < 0) {
      // handle the -inf case: the imaginary part does not really matter here, as
      // exp(value) will be around 0.0 and the angle (i.e. the imaginary part)
      // cannot be determined
      return min;
    } else {
      // handle the +inf case: we don't need the special precision of log1p for
      // small values, and this avoids producing nan when real(max) == real(min) == +inf
      auto exp_min = _fast_build_exp_inf(min);
      auto exp_max = _fast_build_exp_inf(max);
      return ::log1p(exp_min + exp_max - 1); // log1p(x - 1) builds faster than log
    }
  } else {
    auto minmax = min - max;
    auto exp_minmax = _fast_build_exp(minmax);
    return ::log1p(exp_minmax) + max;
  }
}

void launch_logcumsumexp_cuda_kernel(const TensorBase& result, const TensorBase& self, int64_t dim) {
  // Compile time for CUDA-11.4 is 3x slower than with CUDA-11.6+, specifically for complex numbers
#if defined(FBCODE_CAFFE2) || defined(OVRSOURCE)
#define _LCME_DISPATCH AT_DISPATCH_FLOATING_TYPES_AND2
#else
#define _LCME_DISPATCH AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2
#endif
  _LCME_DISPATCH(ScalarType::Half, ScalarType::BFloat16,
      self.scalar_type(), "logcumsumexp_cuda",
      [&]() {
        using opmath_t = at::opmath_type<scalar_t>;
        // the scan starts from -inf, the identity element of log-add-exp:
        // log(exp(-inf) + exp(y)) == y
        scalar_t init = -std::numeric_limits<scalar_t>::infinity();
        auto log_add_exp = [] C10_HOST_DEVICE (const scalar_t x_, const scalar_t y_) -> scalar_t {
          const opmath_t x{x_}, y{y_};
          return _log_add_exp_helper(x, y);
        };
        scan_dim<scalar_t>(self, result, dim, init, log_add_exp);
      });
}
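
// Illustrative usage sketch (an assumption about the call path, not part of
// this file): on a CUDA build, this launcher is ultimately reached through
// the public logcumsumexp op, e.g. in C++:
//   auto x = at::randn({8}, at::kCUDA);
//   auto y = at::logcumsumexp(x, /*dim=*/0);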

} // namespace at::native