1 #define TORCH_ASSERT_NO_OPERATORS
2 #define _USE_MATH_DEFINES
3
4 #include <ATen/native/Activation.h>
5
6 #include <cmath>
7
8 #include <thrust/tuple.h>
9
10 #include <ATen/AccumulateType.h>
11 #include <ATen/Dispatch.h>
12 #include <ATen/core/TensorBase.h>
13 #include <c10/core/Scalar.h>
14 #include <c10/cuda/CUDAMathCompat.h>
15 #include <ATen/cuda/ApplyGridUtils.cuh>
16 #include <ATen/cuda/detail/OffsetCalculator.cuh>
17 #include <ATen/native/cuda/Loops.cuh>
18 #include <c10/util/complex.h>
19
20 namespace at::native {
21 namespace {
22
// Elementwise SiLU forward: out = x * sigmoid(x) = x / (1 + exp(-x)).
// Dispatches over float, double, complex, Half, and BFloat16; the lambda
// accumulates in opmath_t (e.g. float for Half/BFloat16) to limit
// low-precision rounding error.
void silu_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.dtype(),
      "silu_cuda",
      [&]() {
        gpu_kernel(iter, [] GPU_LAMBDA(scalar_t x) -> scalar_t {
          using opmath_t = at::opmath_type<scalar_t>;
          const opmath_t one = opmath_t(1);
          const opmath_t xv = static_cast<opmath_t>(x);
          // x * sigmoid(x), written as a single division.
          return xv / (one + ::exp(-xv));
        });
      });
}
37
// Elementwise SiLU backward: given upstream grad dy and input x, computes
//   dx = dy * s * (1 + x * (1 - s)),  where s = sigmoid(x),
// which is dy * d/dx[x * sigmoid(x)]. Dispatches over float, double, Half,
// and BFloat16, accumulating in opmath_t for precision.
void silu_backward_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.dtype(),
      "silu_backward_cuda",
      [&]() {
        gpu_kernel(iter, [] GPU_LAMBDA(scalar_t dy, scalar_t x) -> scalar_t {
          using opmath_t = at::opmath_type<scalar_t>;
          const opmath_t one = opmath_t(1);
          const opmath_t grad_out = static_cast<opmath_t>(dy);
          const opmath_t xv = static_cast<opmath_t>(x);
          // s = sigmoid(x)
          const opmath_t sig =
              one / (one + c10::cuda::compat::exp(-xv));
          return grad_out * sig * (one + xv * (one - sig));
        });
      });
}
55 } // namespace
56
// Register the CUDA implementations behind the silu dispatch stubs
// (presumably declared in ATen/native/Activation.h, included above).
REGISTER_DISPATCH(silu_stub, &silu_kernel);
REGISTER_DISPATCH(silu_backward_stub, &silu_backward_kernel);
59
60 } // namespace at::native
61