// xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/LogcumsumexpKernel.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/core/TensorBase.h>
#include <ATen/OpMathType.h>
#include <ATen/Dispatch.h>

#include <ATen/native/cuda/ScanKernels.h>
#include <ATen/native/cuda/ScanUtils.cuh>

#include <cmath>
#include <limits>

namespace at::native {

// Custom min/max selector used by logcumsumexp for complex arguments.
// Complex values are ordered by their real component only; a NaN in either
// component of an operand makes that operand "win" (y takes precedence
// when both operands contain NaNs).
template <typename scalar_t, bool min>
__host__ __device__ c10::complex<scalar_t> _logcumsumexp_minmax(const c10::complex<scalar_t>& x, const c10::complex<scalar_t>& y) {
  // NaN operands are returned unchanged so they propagate through the scan.
  if (::isnan(std::real(y)) || ::isnan(std::imag(y))) {
    return y;
  }
  if (::isnan(std::real(x)) || ::isnan(std::imag(x))) {
    return x;
  }
  const scalar_t xr = std::real(x);
  const scalar_t yr = std::real(y);
  // Ties on the real part resolve to x in both the min and max variants.
  return min ? ((xr < yr) ? x : y) : ((xr >= yr) ? x : y);
}
// Numerically stable log(exp(x) + exp(y)) for real scalars: the combine step
// of the logcumsumexp scan.
// Reference : https://www.tensorflow.org/api_docs/python/tf/math/cumulative_logsumexp
template <typename scalar_t>
__host__ __device__ scalar_t _log_add_exp_helper(const scalar_t& x, const scalar_t& y) {
  // Using the original expression: `at::_isnan(y) ? y : std::min(x, y)` causes an error in ROCM
  const auto x_is_nan = at::_isnan(x);
  const auto y_is_nan = at::_isnan(y);
  scalar_t lo, hi;
  if (y_is_nan) {
    lo = hi = y;
  } else if (x_is_nan) {
    lo = hi = x;
  } else {
    lo = std::min(x, y);
    hi = std::max(x, y);
  }
  // A NaN fails `lo != hi`... no: NaN != NaN is true, so NaN takes this branch
  // and propagates through log1p/exp as intended.
  if (lo != hi || ::isfinite(lo)) {
    return ::log1p(std::exp(lo - hi)) + hi;
  }
  // lo == hi and non-finite: both inputs are the same infinity; return it
  // directly instead of computing inf - inf above.
  return x;
}
// Complex exponential exp(x), implemented manually to keep compile time low.
// Only valid when x is finite (neither inf nor nan in either component).
template <typename scalar_t>
__host__ __device__ c10::complex<scalar_t> _fast_build_exp(const c10::complex<scalar_t>& x) {
  // exp(a + bi) = exp(a) * (cos(b) + i*sin(b))
  const auto magnitude = std::exp(std::real(x));
  const auto angle = std::imag(x);
  return {magnitude * std::cos(angle), magnitude * std::sin(angle)};
}
// Complex exponential exp(x), implemented manually to keep compile time low.
// Only handles the case where the real part of x is infinite.
template <typename scalar_t>
__host__ __device__ c10::complex<scalar_t> _fast_build_exp_inf(const c10::complex<scalar_t>& x) {
  const auto inf = std::numeric_limits<scalar_t>::infinity();
  const auto angle = std::imag(x);
  const auto c = std::cos(angle);
  const auto s = std::sin(angle);
  // When the angle is an exact multiple of pi/2, one of cos/sin is exactly 0;
  // short-circuit to 0 to avoid producing inf * 0 == nan in that component.
  const auto re = (c == 0) ? static_cast<scalar_t>(0.0) : inf * c;
  const auto im = (s == 0) ? static_cast<scalar_t>(0.0) : inf * s;
  return {re, im};
}
// Numerically stable log(exp(x) + exp(y)) for complex scalars: the combine
// step of the logcumsumexp scan. Operands are ordered by their real parts via
// _logcumsumexp_minmax so NaNs and infinities can be special-cased.
template <typename scalar_t>
__host__ __device__ c10::complex<scalar_t> _log_add_exp_helper(const c10::complex<scalar_t>& x, const c10::complex<scalar_t>& y) {
  const auto lo = _logcumsumexp_minmax<scalar_t, /*min=*/true>(x, y);
  const auto hi = _logcumsumexp_minmax<scalar_t, /*min=*/false>(x, y);
  const scalar_t lo_real = std::real(lo);
  const scalar_t hi_real = std::real(hi);

  // Handle the "infectious" NaNs: any NaN component poisons the whole result.
  if (::isnan(lo_real) || ::isnan(std::imag(lo))) {
    return {std::numeric_limits<scalar_t>::quiet_NaN(), std::numeric_limits<scalar_t>::quiet_NaN()};
  }

  if (!::isfinite(lo_real) && (lo_real == hi_real)) {
    if (lo_real < 0) {
      // Both real parts are -inf: exp(value) is around 0.0 and the angle
      // (the imaginary part) cannot be determined, so it does not matter
      // which operand we return here.
      return lo;
    }
    // Both real parts are +inf: we don't need log1p's special precision for
    // small values, and this avoids producing nan when
    // real(hi) == real(lo) == +inf.
    return ::log1p(_fast_build_exp_inf(lo) + _fast_build_exp_inf(hi) - 1);  // log1p(x - 1) builds faster than log
  }

  // Generic path: log(exp(lo) + exp(hi)) = log1p(exp(lo - hi)) + hi.
  return ::log1p(_fast_build_exp(lo - hi)) + hi;
}
// Dispatches logcumsumexp over the supported scalar types and runs an
// inclusive scan along `dim` of `self`, writing into `result`.
void launch_logcumsumexp_cuda_kernel(const TensorBase& result, const TensorBase& self, int64_t dim) {
// Compile time for CUDA-11.4 is 3x slower than with CUDA-11.6+, specifically for complex numbers
#if defined(FBCODE_CAFFE2) || defined(OVRSOURCE)
#define _LCME_DISPATCH AT_DISPATCH_FLOATING_TYPES_AND2
#else
#define _LCME_DISPATCH AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2
#endif
  _LCME_DISPATCH(ScalarType::Half, ScalarType::BFloat16,
      self.scalar_type(), "logcumsumexp_cuda",
      [&]() {
        using opmath_t = at::opmath_type<scalar_t>;
        // -inf is the identity of log-add-exp: exp(-inf) == 0.
        scalar_t init = -std::numeric_limits<scalar_t>::infinity();
        auto log_add_exp = [] C10_HOST_DEVICE (const scalar_t x_, const scalar_t y_) -> scalar_t {
          // Widen to opmath_t (e.g. float for Half/BFloat16) before combining.
          const opmath_t x{x_}, y{y_};
          return _log_add_exp_helper(x, y);
        };
        scan_dim<scalar_t>(self, result, dim, init, log_add_exp);
      });
// Fix: the helper macro was never undefined, leaking into the rest of the
// translation unit (redefinition hazard in unity/amalgamated builds).
#undef _LCME_DISPATCH
}

} // namespace at::native