#define TORCH_ASSERT_NO_OPERATORS
#define _USE_MATH_DEFINES

#include <ATen/native/Activation.h>

#include <cmath>

#include <thrust/tuple.h>

#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/core/TensorBase.h>
#include <c10/core/Scalar.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/cuda/ApplyGridUtils.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/Loops.cuh>

namespace at::native {

// -----------------------------------
// glu forward
// -----------------------------------
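// glu splits the glu dimension of the input in half and computes
// glu(a, b) = a * sigmoid(b), where a comes from the first half and b from
// the corresponding position in the second half; the TensorIterator hands
// the two halves to the lambda as separate operands. The arithmetic runs in
// opmath_t (float for half/bfloat16 inputs) before narrowing back to
// scalar_t.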
void glu_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf, kBFloat16, iter.dtype(), "glu_cuda", [&]() {
        using opmath_t = at::opmath_type<scalar_t>;
        gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a_, scalar_t b_) -> scalar_t {
          const opmath_t a = a_;
          const opmath_t b = b_;
          const opmath_t one = opmath_t(1);
          const opmath_t sigmoid = one / (one + std::exp(-b));
          return a * sigmoid;
        });
      });
}

// -----------------------------------
// glu forward ad
// -----------------------------------
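// Forward-mode AD of res = a * sigmoid(b):
//   d(res) = da * sigmoid(b) + a * sigmoid(b) * (1 - sigmoid(b)) * db
//          = da * sigmoid(b) + res * (db - sigmoid(b) * db),
// which is exactly the expression the lambda below returns as
// da * sig_b + res * (db - sig_b * db).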
void glu_jvp_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf, kBFloat16, iter.dtype(), "glu_cuda", [&]() {
        using opmath_t = at::opmath_type<scalar_t>;
        gpu_kernel(
            iter,
            [] GPU_LAMBDA(
                scalar_t res_, scalar_t b_, scalar_t da_, scalar_t db_)
                -> scalar_t {
              const opmath_t res = res_;
              const opmath_t b = b_;
              const opmath_t da = da_;
              const opmath_t db = db_;
              const opmath_t one = opmath_t(1);

              const opmath_t sig_b = one / (one + std::exp(-b));
              return (da * sig_b + res * (db - sig_b * db));
            });
      });
}

// -----------------------------------
// glu backward
// -----------------------------------

// Byte offsets don't require multiplication by sizeof(T), so are slightly
// cheaper. For fixed offsets, this removes all penalty from 64-bit indexing.
template <typename T>
__device__ T* byte_offset(T* ptr, int64_t offset) {
  using byte_ptr_t = typename std::
      conditional<std::is_const<T>::value, const char*, char*>::type;
  return reinterpret_cast<T*>(reinterpret_cast<byte_ptr_t>(ptr) + offset);
}
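// For example, given a float* p, byte_offset(p, 4 * sizeof(float)) addresses
// the same element as p + 4; the launcher below pre-scales its strides by
// sizeof(scalar_t) once on the host so the kernel only does byte arithmetic.

// Backward of res = a * sigmoid(b) with incoming gradient gO:
//   d(res)/da = sigmoid(b)
//   d(res)/db = a * sigmoid(b) * (1 - sigmoid(b)),
// so the kernel writes gA = sigmoid * gO and
// gB = (1 - sigmoid) * sigmoid * gO * a.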
template <typename scalar_t, typename OffsetCalc>
__global__ void glu_backward_kernel(
    int numel,
    scalar_t* gI,
    const scalar_t* I,
    const scalar_t* gO,
    OffsetCalc offset_calculator,
    int64_t gI_byte_offset,
    int64_t I_byte_offset) {
  using opmath_t = at::opmath_type<scalar_t>;

  const uint32_t linear_index = blockIdx.x * blockDim.x + threadIdx.x;
  if (linear_index >= numel) {
    return;
  }
  const auto offsets = offset_calculator.get(linear_index);

  // We explicitly iterate over the first half of the input tensor, and
  // gI_byte_offset and I_byte_offset are the offsets to access the
  // corresponding index in the second half of the tensor.
  const opmath_t a = I[offsets[1]];
  const opmath_t b = *byte_offset(I + offsets[1], I_byte_offset);
  const opmath_t gO_val = gO[offsets[2]];

  const auto one = opmath_t(1);
  const opmath_t sigmoid = one / (one + std::exp(-b));

  auto* gA = gI + offsets[0];
  *gA = sigmoid * gO_val;

  auto* gB = byte_offset(gA, gI_byte_offset);
  *gB = (one - sigmoid) * sigmoid * gO_val * a;
}

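// Launches one thread per element of the first half, using a standard
// ceiling-division grid: grid = (N + block_size - 1) / block_size with
// block_size = 256. The kernel indexes with 32-bit arithmetic, so the
// debug assert below requires N to fit in int32_t; callers are expected to
// keep the iterator 32-bit indexable (e.g. by splitting it with
// TensorIterator::with_32bit_indexing()) before reaching this launcher.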
void launch_glu_backward_kernel(
    const TensorIteratorBase& iter,
    int64_t gI_stride,
    int64_t I_stride) {
  const auto N = iter.numel();
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      N > 0 && N <= std::numeric_limits<int32_t>::max());
  const auto offset_calculator = make_element_offset_calculator<3>(iter);
  constexpr int64_t block_size = 256;
  const int64_t grid = (N + block_size - 1) / block_size;
  const auto stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf, kBFloat16, iter.common_dtype(), "glu_backward_cuda", [&] {
        auto gI = static_cast<scalar_t*>(iter.data_ptr(0));
        auto I = static_cast<const scalar_t*>(iter.data_ptr(1));
        auto gO = static_cast<const scalar_t*>(iter.data_ptr(2));
        glu_backward_kernel<<<grid, block_size, 0, stream>>>(
            N,
            gI,
            I,
            gO,
            offset_calculator,
            gI_stride * sizeof(scalar_t),
            I_stride * sizeof(scalar_t));
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
}

REGISTER_DISPATCH(glu_stub, &glu_kernel);
REGISTER_DISPATCH(glu_jvp_stub, &glu_jvp_kernel);
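// Note: no stub is registered for the backward here;
// launch_glu_backward_kernel is presumably invoked directly by the
// glu_backward implementation through its declaration in the included
// ATen/native/Activation.h header.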

} // namespace at::native