#define TORCH_ASSERT_NO_OPERATORS
#define _USE_MATH_DEFINES

#include <ATen/native/Activation.h>

#include <cmath>
#include <limits>
#include <type_traits>

#include <thrust/tuple.h>

#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/core/TensorBase.h>
#include <c10/core/Scalar.h>
#include <c10/cuda/CUDAMathCompat.h>
#include <ATen/cuda/ApplyGridUtils.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/native/cuda/Loops.cuh>

namespace at::native {

// -----------------------------------
// glu forward
// -----------------------------------
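// GLU (gated linear unit): the input is split into two halves a and b along
// the chosen dimension and the result is GLU(a, b) = a * sigmoid(b). The two
// halves arrive here as the iterator's two inputs, so this kernel only
// computes the elementwise combination.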
void glu_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf, kBFloat16, iter.dtype(), "glu_cuda", [&]() {
        using opmath_t = at::opmath_type<scalar_t>;
        gpu_kernel(iter, [] GPU_LAMBDA(scalar_t a_, scalar_t b_) -> scalar_t {
          const opmath_t a = a_;
          const opmath_t b = b_;
          const opmath_t one = opmath_t(1);
          const opmath_t sigmoid = one / (one + std::exp(-b));
          return a * sigmoid;
        });
      });
}

// -----------------------------------
// glu forward ad
// -----------------------------------
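// Forward-mode AD (JVP). With sig = sigmoid(b) and res = a * sig, the
// directional derivative along (da, db) is
//   d(res) = da * sig + a * sig * (1 - sig) * db
//          = da * sig + res * (db - sig * db),
// which lets us reuse the already-computed forward result res instead of a.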
void glu_jvp_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf, kBFloat16, iter.dtype(), "glu_cuda", [&]() {
        using opmath_t = at::opmath_type<scalar_t>;
        gpu_kernel(
            iter,
            [] GPU_LAMBDA(
                scalar_t res_, scalar_t b_, scalar_t da_, scalar_t db_)
                -> scalar_t {
              const opmath_t res = res_;
              const opmath_t b = b_;
              const opmath_t da = da_;
              const opmath_t db = db_;
              const opmath_t one = opmath_t(1);

              const opmath_t sig_b = one / (one + std::exp(-b));
              return (da * sig_b + res * (db - sig_b * db));
            });
      });
}

// -----------------------------------
// glu backward
// -----------------------------------

// Byte offsets don't require multiplication by sizeof(T), so are slightly
// cheaper. For fixed offsets, this removes all penalty from 64-bit indexing.
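// For example, *byte_offset(I + offsets[1], I_byte_offset) below is a plain
// pointer add of a byte count precomputed on the host, whereas indexing with
// an element stride would scale by sizeof(scalar_t) inside the kernel.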
template <typename T>
__device__ T* byte_offset(T* ptr, int64_t offset) {
  using byte_ptr_t = typename std::
      conditional<std::is_const<T>::value, const char*, char*>::type;
  return reinterpret_cast<T*>(reinterpret_cast<byte_ptr_t>(ptr) + offset);
}

template <typename scalar_t, typename OffsetCalc>
__global__ void glu_backward_kernel(
    int numel,
    scalar_t* gI,
    const scalar_t* I,
    const scalar_t* gO,
    OffsetCalc offset_calculator,
    int64_t gI_byte_offset,
    int64_t I_byte_offset) {
  using opmath_t = at::opmath_type<scalar_t>;

  const uint32_t linear_index = blockIdx.x * blockDim.x + threadIdx.x;
  if (linear_index >= numel) {
    return;
  }
  const auto offsets = offset_calculator.get(linear_index);

  // We explicitly iterate over the first half of the input tensor, and
  // gI_byte_offset and I_byte_offset are the offsets to access the
  // corresponding index in the second half of the tensor.
  const opmath_t a = I[offsets[1]];
  const opmath_t b = *byte_offset(I + offsets[1], I_byte_offset);
  const opmath_t gO_val = gO[offsets[2]];

  const auto one = opmath_t(1);
  const opmath_t sigmoid = one / (one + std::exp(-b));

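  // With res = a * sigmoid(b), the gradients w.r.t. the two halves are
  //   d res / d a = sigmoid(b)
  //   d res / d b = a * sigmoid(b) * (1 - sigmoid(b)),
  // each scaled by the incoming gradient gO_val.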
  auto* gA = gI + offsets[0];
  *gA = sigmoid * gO_val;

  auto* gB = byte_offset(gA, gI_byte_offset);
  *gB = (one - sigmoid) * sigmoid * gO_val * a;
}

void launch_glu_backward_kernel(
    const TensorIteratorBase& iter,
    int64_t gI_stride,
    int64_t I_stride) {
  const auto N = iter.numel();
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      N > 0 && N <= std::numeric_limits<int32_t>::max());
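  // The kernel takes a 32-bit element count and computes a 32-bit linear
  // index, so N must fit in int32_t; callers are expected to split larger
  // iterators into 32-bit-indexable pieces before invoking this launcher.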
  const auto offset_calculator = make_element_offset_calculator<3>(iter);
  constexpr int64_t block_size = 256;
  const int64_t grid = (N + block_size - 1) / block_size;
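  // Ceiling division: enough blocks of block_size threads to cover all N
  // elements; out-of-range threads exit early in the kernel.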
  const auto stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_FLOATING_TYPES_AND2(
      kHalf, kBFloat16, iter.common_dtype(), "glu_backward_cuda", [&] {
        auto gI = static_cast<scalar_t*>(iter.data_ptr(0));
        auto I = static_cast<const scalar_t*>(iter.data_ptr(1));
        auto gO = static_cast<const scalar_t*>(iter.data_ptr(2));
        glu_backward_kernel<<<grid, block_size, 0, stream>>>(
            N,
            gI,
            I,
            gO,
            offset_calculator,
            gI_stride * sizeof(scalar_t),
            I_stride * sizeof(scalar_t));
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
}

REGISTER_DISPATCH(glu_stub, &glu_kernel);
REGISTER_DISPATCH(glu_jvp_stub, &glu_jvp_kernel);

} // namespace at::native