#define TORCH_ASSERT_NO_OPERATORS
#include <ATen/Dispatch.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/BinaryOps.h>
#include <ATen/OpMathType.h>
#include <c10/util/MathConstants.h>

// NOTE: CUDA on Windows requires that the enclosing function
// of a __device__ lambda not have internal linkage.

namespace at::native {

void logaddexp_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::BFloat16, ScalarType::Half,
      iter.dtype(), "logaddexp_cuda",
      [&]() {
        using opmath_t = at::opmath_type<scalar_t>;
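        // Numerically stable evaluation: with m = max(a, b) and d = |a - b|,
        // log(exp(a) + exp(b)) = m + log1p(exp(-d)).
        // The isinf branch handles a == b == +/-infinity, where d would be
        // inf - inf = NaN even though the result (a itself) is well defined.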
        gpu_kernel(iter, [] GPU_LAMBDA (scalar_t a_, scalar_t b_) -> scalar_t {
          const auto a = static_cast<opmath_t>(a_);
          const auto b = static_cast<opmath_t>(b_);
          if (::isinf(a) && a == b) {
            return a;
          } else {
            const auto m = ::max(a, b);
            return m + ::log1p(::exp(-::abs(a - b)));
          }
        });
      });
}

void logaddexp2_kernel_cuda(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      ScalarType::BFloat16, ScalarType::Half,
      iter.dtype(), "logaddexp2_cuda",
      [&]() {
        using opmath_t = at::opmath_type<scalar_t>;
        const auto inv_log_2 = static_cast<opmath_t>(1.0 / c10::ln_2<double>);
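        // Base-2 counterpart: with m = max(a, b) and d = |a - b|,
        // log2(2^a + 2^b) = m + log2(1 + 2^(-d)), and the log2 term is evaluated
        // as log1p(exp2(-d)) * inv_log_2, i.e. ln(1 + 2^(-d)) / ln(2).
        // As above, the isinf branch handles a == b == +/-infinity.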
        gpu_kernel(iter, [inv_log_2] GPU_LAMBDA (scalar_t a_, scalar_t b_) -> scalar_t {
          const auto a = static_cast<opmath_t>(a_);
          const auto b = static_cast<opmath_t>(b_);
          if (::isinf(a) && a == b) {
            return a;
          } else {
            const auto m = ::max(a, b);
            return m + ::log1p(::exp2(-::abs(a - b))) * inv_log_2;
          }
        });
      });
}

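// Wire the CUDA implementations into the logaddexp/logaddexp2 dispatch stubs
// declared in ATen/native/BinaryOps.h (included above).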
REGISTER_DISPATCH(logaddexp_stub, &logaddexp_kernel_cuda);
REGISTER_DISPATCH(logaddexp2_stub, &logaddexp2_kernel_cuda);

} // namespace at::native