1 #define TORCH_ASSERT_NO_OPERATORS
2 #define _USE_MATH_DEFINES
3
4 #include <ATen/native/Activation.h>
5
6 #include <cmath>
7
8 #include <thrust/tuple.h>
9
10 #include <ATen/AccumulateType.h>
11 #include <ATen/Dispatch.h>
12 #include <ATen/core/TensorBase.h>
13 #include <c10/core/Scalar.h>
14 #include <c10/cuda/CUDAMathCompat.h>
15 #include <ATen/cuda/ApplyGridUtils.cuh>
16 #include <ATen/cuda/detail/OffsetCalculator.cuh>
17 #include <ATen/native/cuda/Loops.cuh>
18
19 namespace at::native {
20 namespace {
21
// Element-wise ELU forward on CUDA.
//
// For every element x produced by `iter`, writes
//   scale * x                                   if x > 0
//   alpha * scale * expm1(input_scale * x)      otherwise (including NaN)
// Arithmetic is carried out in at::opmath_type<scalar_t> (presumably float for
// Half/BFloat16 — confirm against ATen's OpMathType.h) and the value is cast
// back to scalar_t on return.
//
// Dtypes: the floating types plus Half and BFloat16 (see the dispatch macro).
void elu_kernel(
    TensorIteratorBase& iter,
    const Scalar& alpha,
    const Scalar& scale,
    const Scalar& input_scale) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.dtype(),
      "elu_cuda",
      [&]() {
        using opmath_t = at::opmath_type<scalar_t>;
        // Convert the host-side Scalars once; the device lambda captures the
        // resulting plain values by copy.
        const opmath_t pos_scale = scale.to<opmath_t>();
        const opmath_t neg_scale = alpha.to<opmath_t>() * pos_scale;
        const opmath_t in_scale = input_scale.to<opmath_t>();
        gpu_kernel(
            iter,
            [pos_scale, neg_scale, in_scale] GPU_LAMBDA(scalar_t x) -> scalar_t {
              const opmath_t v = static_cast<opmath_t>(x);
              if (v > 0) {
                // Linear branch.
                return v * pos_scale;
              }
              // Exponential branch; expm1 keeps precision when
              // in_scale * v is close to zero.
              return neg_scale * std::expm1(in_scale * v);
            });
      });
}
46
// Element-wise ELU backward on CUDA.
//
// `iter` supplies two operands per element: `grad`, which is multiplied by the
// local derivative, and `ref`, which is the forward input x when
// `is_result == false` or the forward output y when `is_result == true`.
// Writes
//   grad * input_scale * (ref + alpha * scale)                  ref <= 0, is_result
//   grad * input_scale * alpha * scale * exp(ref * input_scale) ref <= 0, !is_result
//   grad * scale                                                otherwise (ref > 0 or NaN)
// The is_result form relies on alpha*scale*exp(input_scale*x) == y + alpha*scale,
// which holds where y = alpha*scale*expm1(input_scale*x) (the forward formula).
// NOTE(review): testing `ref <= 0` on the *output* selects the exponential
// branch correctly only when alpha * scale >= 0; presumably callers reject
// is_result with a negative alpha — confirm.
void elu_backward_kernel(
    TensorIteratorBase& iter,
    const Scalar& alpha,
    const Scalar& scale,
    const Scalar& input_scale,
    bool is_result) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.dtype(),
      "elu_backward_cuda",
      [&]() {
        using opmath_t = at::opmath_type<scalar_t>;
        const opmath_t pos_scale = scale.to<opmath_t>();
        const opmath_t neg_scale = alpha.to<opmath_t>() * pos_scale;
        const opmath_t in_scale = input_scale.to<opmath_t>();
        gpu_kernel(
            iter,
            [pos_scale, neg_scale, in_scale, is_result] GPU_LAMBDA(
                scalar_t grad, scalar_t ref) -> scalar_t {
              const opmath_t g = static_cast<opmath_t>(grad);
              const opmath_t r = static_cast<opmath_t>(ref);
              // Tested as `r <= 0` (not `r > 0`) so that NaN falls through
              // to the linear branch at the bottom.
              if (r <= 0) {
                const opmath_t g_in = g * in_scale;
                if (is_result) {
                  // r is the forward output: r + neg_scale == neg_scale * exp(in_scale * x).
                  return g_in * (r + neg_scale);
                }
                // r is the forward input.
                return g_in * neg_scale * std::exp(r * in_scale);
              }
              return g * pos_scale;
            });
      });
}
81 } // namespace
82
// Bind the CUDA kernels defined above to the ELU dispatch stubs
// (elu_stub / elu_backward_stub — presumably declared in
// ATen/native/Activation.h, which this file includes; confirm).
REGISTER_DISPATCH(elu_stub, &elu_kernel);
REGISTER_DISPATCH(elu_backward_stub, &elu_backward_kernel);
85
86 } // namespace at::native
87