xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/LogAddExpKernel.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_NO_OPERATORS
2 #include <ATen/Dispatch.h>
3 #include <ATen/native/DispatchStub.h>
4 #include <ATen/native/cuda/Loops.cuh>
5 #include <ATen/native/TensorIterator.h>
6 #include <ATen/native/BinaryOps.h>
7 #include <ATen/OpMathType.h>
8 #include <c10/util/MathConstants.h>
9 
10 // NOTE: CUDA on Windows requires that the enclosing function
11 // of a __device__ lambda not have internal linkage.
12 
13 namespace at::native {
14 
// Elementwise logaddexp(a, b) = log(exp(a) + exp(b)), computed in the
// numerically stable form max(a, b) + log1p(exp(-|a - b|)) so that large
// inputs do not overflow exp(). Accumulation is done in opmath_t (float for
// half/bfloat16) for precision.
void logaddexp_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::BFloat16, ScalarType::Half,
      iter.dtype(), "logaddexp_cuda",
      [&]() {
        using opmath_t = at::opmath_type<scalar_t>;
        gpu_kernel(iter, [] GPU_LAMBDA (scalar_t lhs, scalar_t rhs) -> scalar_t {
          const opmath_t x = static_cast<opmath_t>(lhs);
          const opmath_t y = static_cast<opmath_t>(rhs);
          // Both operands are the same infinity (+inf or -inf): return it
          // directly, since x - y below would produce inf - inf = NaN.
          if (::isinf(x) && x == y) {
            return x;
          }
          const opmath_t larger = ::max(x, y);
          return larger + ::log1p(::exp(-::abs(x - y)));
        });
      });
}
33 
// Elementwise logaddexp2(a, b) = log2(2^a + 2^b), computed stably as
// max(a, b) + log2(1 + 2^(-|a - b|)). log2(1 + t) is evaluated as
// log1p(t) / ln(2) (the multiply by inv_log_2 below) to stay accurate for
// small t. Accumulation is done in opmath_t (float for half/bfloat16).
void logaddexp2_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::BFloat16, ScalarType::Half,
      iter.dtype(), "logaddexp2_cuda",
      [&]() {
        using opmath_t = at::opmath_type<scalar_t>;
        // 1 / ln(2), hoisted out of the lambda and captured by value.
        const auto inv_log_2 = static_cast<opmath_t>(1.0 / c10::ln_2<double>);
        gpu_kernel(iter, [inv_log_2] GPU_LAMBDA (scalar_t lhs, scalar_t rhs) -> scalar_t {
          const opmath_t x = static_cast<opmath_t>(lhs);
          const opmath_t y = static_cast<opmath_t>(rhs);
          // Both operands are the same infinity (+inf or -inf): return it
          // directly, since x - y below would produce inf - inf = NaN.
          if (::isinf(x) && x == y) {
            return x;
          }
          const opmath_t larger = ::max(x, y);
          return larger + ::log1p(::exp2(-::abs(x - y))) * inv_log_2;
        });
      });
}
53 
// Register the CUDA implementations with the device-generic dispatch stubs
// declared in ATen/native/BinaryOps.h.
REGISTER_DISPATCH(logaddexp_stub, &logaddexp_kernel_cuda);
REGISTER_DISPATCH(logaddexp2_stub, &logaddexp2_kernel_cuda);
56 
57 } // namespace at::native
58