xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/ActivationSiluKernel.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_NO_OPERATORS
2 #define _USE_MATH_DEFINES
3 
4 #include <ATen/native/Activation.h>
5 
6 #include <cmath>
7 
8 #include <thrust/tuple.h>
9 
10 #include <ATen/AccumulateType.h>
11 #include <ATen/Dispatch.h>
12 #include <ATen/core/TensorBase.h>
13 #include <c10/core/Scalar.h>
14 #include <c10/cuda/CUDAMathCompat.h>
15 #include <ATen/cuda/ApplyGridUtils.cuh>
16 #include <ATen/cuda/detail/OffsetCalculator.cuh>
17 #include <ATen/native/cuda/Loops.cuh>
18 #include <c10/util/complex.h>
19 
20 namespace at::native {
21 namespace {
22 
// Elementwise SiLU forward pass: for every input element x, the value
// x / (1 + exp(-x)) is produced through gpu_kernel on `iter`.
//
// Dispatches on iter.dtype() over the floating-point and complex types plus
// Half and BFloat16. The arithmetic is carried out in
// at::opmath_type<scalar_t>; the result is converted back to scalar_t when
// the lambda returns.
void silu_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.dtype(),
      "silu_cuda",
      [&]() {
        gpu_kernel(iter, [] GPU_LAMBDA(scalar_t input) -> scalar_t {
          using acc_t = at::opmath_type<scalar_t>;
          const acc_t value = static_cast<acc_t>(input);
          // Denominator of the logistic function, 1 + exp(-x). The
          // global-namespace exp is used because this dispatch also
          // instantiates the lambda for complex scalar types.
          const acc_t denom = acc_t(1) + ::exp(-value);
          return value / denom;
        });
      });
}
37 
// Elementwise SiLU backward pass. For each pair (dy, x) it produces
//   dy * s * (1 + x * (1 - s)),  where s = 1 / (1 + exp(-x)).
//
// Dispatches on iter.dtype() over the floating-point types plus Half and
// BFloat16 (no complex types here). The arithmetic is carried out in
// at::opmath_type<scalar_t>; the result is converted back to scalar_t when
// the lambda returns.
void silu_backward_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half,
      at::ScalarType::BFloat16,
      iter.dtype(),
      "silu_backward_cuda",
      [&]() {
        gpu_kernel(
            iter,
            [] GPU_LAMBDA(scalar_t grad_out, scalar_t input) -> scalar_t {
              using acc_t = at::opmath_type<scalar_t>;
              const acc_t grad = static_cast<acc_t>(grad_out);
              const acc_t value = static_cast<acc_t>(input);
              // Logistic sigmoid of the input: s = 1 / (1 + exp(-x)).
              const acc_t sig =
                  acc_t(1) / (acc_t(1) + c10::cuda::compat::exp(-value));
              // Second factor of the derivative: 1 + x * (1 - s).
              const acc_t slope = acc_t(1) + value * (acc_t(1) - sig);
              // Keep the (dy * s) * slope association of the product.
              return grad * sig * slope;
            });
      });
}
55 } // namespace
56 
// Bind silu_kernel and silu_backward_kernel (defined in this file) to the
// silu_stub and silu_backward_stub dispatch stubs.
57 REGISTER_DISPATCH(silu_stub, &silu_kernel);
58 REGISTER_DISPATCH(silu_backward_stub, &silu_backward_kernel);
59 
60 } // namespace at::native
61