xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/ActivationEluKernel.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_NO_OPERATORS
2 #define _USE_MATH_DEFINES
3 
4 #include <ATen/native/Activation.h>
5 
6 #include <cmath>
7 
8 #include <thrust/tuple.h>
9 
10 #include <ATen/AccumulateType.h>
11 #include <ATen/Dispatch.h>
12 #include <ATen/core/TensorBase.h>
13 #include <c10/core/Scalar.h>
14 #include <c10/cuda/CUDAMathCompat.h>
15 #include <ATen/cuda/ApplyGridUtils.cuh>
16 #include <ATen/cuda/detail/OffsetCalculator.cuh>
17 #include <ATen/native/cuda/Loops.cuh>
18 
19 namespace at::native {
20 namespace {
21 
// Forward ELU: out = scale * x                          for x > 0
//              out = scale * alpha * expm1(x * input_scale) for x <= 0
// Coefficients are folded on the host once per dispatch so the device
// lambda only multiplies; math runs in opmath_t (float for half/bf16).
void elu_kernel(
    TensorIteratorBase& iter,
    const Scalar& alpha,
    const Scalar& scale,
    const Scalar& input_scale) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.dtype(),
      "elu_cuda",
      [&]() {
        using opmath_t = at::opmath_type<scalar_t>;
        // alpha is pre-multiplied into the negative-side coefficient.
        const auto neg_coef = alpha.to<opmath_t>() * scale.to<opmath_t>();
        const auto pos_coef = scale.to<opmath_t>();
        const auto neg_input_coef = input_scale.to<opmath_t>();
        gpu_kernel(
            iter,
            [neg_coef, pos_coef, neg_input_coef] GPU_LAMBDA(
                scalar_t a) -> scalar_t {
              const opmath_t x = static_cast<opmath_t>(a);
              if (x > opmath_t(0)) {
                return x * pos_coef;
              }
              // NaN inputs fall through here, matching the original
              // ternary's else-branch; expm1 propagates the NaN.
              return std::expm1(x * neg_input_coef) * neg_coef;
            });
      });
}
46 
// Backward ELU. `a` is the incoming gradient; `b` is either the forward
// *input* (is_result == false) or the forward *output* (is_result == true).
// When the saved tensor is the result, the negative-side derivative is
// recovered algebraically as input_scale * (result + alpha * scale),
// avoiding a transcendental; otherwise it is recomputed via exp.
void elu_backward_kernel(
    TensorIteratorBase& iter,
    const Scalar& alpha,
    const Scalar& scale,
    const Scalar& input_scale,
    bool is_result) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.dtype(),
      "elu_backward_cuda",
      [&]() {
        using opmath_t = at::opmath_type<scalar_t>;
        const auto neg_coef = alpha.to<opmath_t>() * scale.to<opmath_t>();
        const auto pos_coef = scale.to<opmath_t>();
        const auto neg_input_coef = input_scale.to<opmath_t>();
        gpu_kernel(
            iter,
            [neg_coef, pos_coef, neg_input_coef, is_result] GPU_LAMBDA(
                scalar_t a, scalar_t b) -> scalar_t {
              const opmath_t grad = static_cast<opmath_t>(a);
              const opmath_t saved = static_cast<opmath_t>(b);
              // Guard on <= 0 so a NaN `saved` falls through to the
              // positive branch, exactly as the original ternary did.
              if (saved <= opmath_t(0)) {
                if (is_result) {
                  return grad * neg_input_coef * (saved + neg_coef);
                }
                return grad * neg_input_coef * neg_coef *
                    std::exp(saved * neg_input_coef);
              }
              return grad * pos_coef;
            });
      });
}
81 } // namespace
82 
// Register the CUDA implementations with the device-generic dispatch stubs
// declared in ATen/native/Activation.h.
REGISTER_DISPATCH(elu_stub, &elu_kernel);
REGISTER_DISPATCH(elu_backward_stub, &elu_backward_kernel);
85 
86 } // namespace at::native
87