#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/cuda/EmbeddingBackwardKernel.cuh>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/cub.cuh>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/cuda/SortingCommon.cuh>

#include <c10/macros/Macros.h>

#if CUB_SUPPORTS_UNIQUE_BY_KEY()
#include <thrust/iterator/counting_iterator.h>
#endif

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/zeros.h>
#endif

namespace at::native {

namespace {

/* This code computes the sum of the weights in two steps:
  1) Each GPU warp sums `NROWS_PER_THREAD` rows given by `indices`
  2) The partial sums from 1) are summed and scattered into `grad_weight`

  Note that `NROWS_PER_THREAD` impacts the achieved occupancy of the
  kernel execution. If it is too high, the thread blocks will be too small
  to achieve good occupancy. Similarly, a very low value will make the
  thread blocks in the final sum of step 2) too small.
*/
constexpr int NROWS_PER_THREAD = 10;
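
// Illustrative example (with NROWS_PER_THREAD = 10): if the sorted indices are
// [3, 3, 3, 7, 7], there are two segments, rows [0, 3) for index 3 and rows
// [3, 5) for index 7, so segment_offsets = [0, 3]. Neither segment has more
// than 10 rows, so each one becomes exactly one partial segment.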

// Fast ceil division (no overflow checking)
__host__ __device__ __forceinline__
int64_t ceil_div(int64_t x, int64_t y) {
  return (x + y - 1) / y;
}

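// For each segment, compute how many partial segments of at most
// NROWS_PER_THREAD rows it is split into; e.g. a segment of 25 rows yields
// ceil_div(25, 10) = 3 partial segments. One thread handles one segment.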
template <typename index_t>
__global__
void krn_partials_per_segment(index_t *ret, const index_t *segment_offsets,
                              const int64_t *num_of_segments_ptr, int64_t numel) {
  int64_t num_of_segments = *num_of_segments_ptr;
  const int id = blockIdx.x * blockDim.x + threadIdx.x;
  if(id < num_of_segments) {
    const int64_t idx_start = segment_offsets[id];
    const int64_t idx_end = (id == num_of_segments-1)?numel:segment_offsets[id+1];
    const int64_t size = idx_end - idx_start;
    ret[id] = ceil_div(size, NROWS_PER_THREAD);
  }
}

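// For each segment, write the start offsets of its partial segments (indices
// into `sorted_indices`/`orig_indices`) into `ret`, beginning at position
// partials_per_segment_offset[id]. Consecutive partial segments of one
// segment start NROWS_PER_THREAD rows apart. One thread handles one segment.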
template <typename index_t>
__global__
void krn_partial_segment_offset(
        index_t *ret,
        const index_t *partials_per_segment,
        const index_t *partials_per_segment_offset,
        const index_t *segment_offsets,
        const int64_t *num_of_segments_ptr) {
  int64_t num_of_segments = *num_of_segments_ptr;
  const int id = blockIdx.x * blockDim.x + threadIdx.x;
  if(id < num_of_segments) {
    index_t idx = partials_per_segment_offset[id];
    const index_t num_partials = partials_per_segment[id];
    const index_t segment_offset = segment_offsets[id];
    for (int64_t i=0; i<num_partials; ++i) {
      ret[idx++] = segment_offset + i * NROWS_PER_THREAD;
    }
  }
}


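// Step 1 (embedding_bag variant): each thread accumulates the gradient of one
// feature column over the rows of one partial segment. Note that at the call
// site, `segment_offsets` and `num_of_segments_ptr` receive the *partial*-segment
// offsets and count. `count` rescales duplicate indices, `bag_size` divides the
// gradient in mean mode, and `per_sample_weights` scales individual samples.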
template <typename scalar_t, typename index_t>
__global__ void compute_grad_weight_bags(
    const index_t *indices, const scalar_t *gradOutput,
    const index_t *offset2bag, const index_t *count, ptrdiff_t numel,
    int64_t stride, int mode_mean, const index_t *bag_size,
    const scalar_t* per_sample_weights, int64_t per_sample_weights_stride,
    const index_t* segment_offsets, const int64_t *num_of_segments_ptr,
    acc_type<scalar_t, true> *grad_weight_per_segment,
    const int64_t stride_warped) {

  int64_t num_of_segments = *num_of_segments_ptr;
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
  const int id = gid / stride_warped;
  const int startFeature = gid % stride_warped;
  if (startFeature >= stride) {
    return;
  }
  if (id >= num_of_segments) {
    return;
  }
  const int idx_begin = segment_offsets[id];
  const int idx_end = (id == num_of_segments-1)?numel:segment_offsets[id+1];

  acc_type<scalar_t, true> weight = 0;
  for (int idx=idx_begin; idx < idx_end; ++idx) {
    const int origRow = indices[idx];
    const int seq_number = offset2bag[origRow];
    const int gradOutputRow = seq_number * stride;

    acc_type<scalar_t, true> scale = count ? 1.0 / count[idx] : 1.0;
    if (per_sample_weights) {
      scale *= per_sample_weights[origRow * per_sample_weights_stride];
    }

    acc_type<scalar_t, true> gradient = gradOutput[gradOutputRow + startFeature];
    if (mode_mean) {
      gradient /= bag_size[seq_number];
    }
    weight += gradient * scale;
  }
  grad_weight_per_segment[id * stride + startFeature] = weight;
}

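// Step 1 (plain embedding variant): each thread accumulates the gradient of
// one feature column over the rows of one partial segment and stores it in
// grad_weight_per_segment[id * stride + startFeature]. As above, the call site
// passes the partial-segment offsets and count for these arguments.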
template <typename scalar_t, typename index_t>
__global__ void compute_grad_weight(
    const index_t *indices,
    const scalar_t *gradOutput,
    const index_t *count,
    ptrdiff_t numel,
    int64_t stride,
    const index_t* segment_offsets,
    const int64_t *num_of_segments_ptr,
    acc_type<scalar_t, true> *grad_weight_per_segment,
    const int64_t stride_warped) {

  int64_t num_of_segments = *num_of_segments_ptr;
  using accscalar_t = acc_type<scalar_t, true>;
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
  const int id = gid / stride_warped;
  const int startFeature = gid % stride_warped;
  if (startFeature >= stride) {
    return;
  }
  if (id >= num_of_segments) {
    return;
  }
  const int idx_begin = segment_offsets[id];
  const int idx_end = (id == num_of_segments-1)?numel:segment_offsets[id+1];

  accscalar_t weight = 0;
  for (int idx=idx_begin; idx < idx_end; ++idx) {
    const index_t target_row = indices[idx];
    const accscalar_t scale = count ? (accscalar_t)1.0 / count[idx] : 1.0;
    weight += gradOutput[target_row * stride + startFeature] * scale;
  }
  grad_weight_per_segment[id * stride + startFeature] = weight;
}

// This kernel assumes that all input tensors are contiguous.
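// Step 2: for one feature column, each thread sums the entries of
// grad_weight_per_segment that belong to one full segment (the range given by
// segment_sizes_offsets) and writes the total into the row of gradWeight
// addressed by that segment's index, unless the index equals padding_idx.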
template <typename scalar_t, typename index_t>
__global__ void sum_and_scatter(
    const index_t *input, scalar_t *gradWeight, int64_t stride,
    const index_t* segment_offsets, const int64_t *num_of_segments_ptr,
    const acc_type<scalar_t, true> *grad_weight_per_segment,
    const index_t *segment_sizes_offsets, const int64_t *num_of_partial_segments_ptr,
    const int64_t padding_idx,
    const int64_t stride_warped) {

  int64_t num_of_segments = *num_of_segments_ptr;
  int64_t num_of_partial_segments = *num_of_partial_segments_ptr;
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
  const int id = gid / stride_warped;
  const int startFeature = gid % stride_warped;
  if (startFeature >= stride) {
    return;
  }
  if (id >= num_of_segments) {
    return;
  }

  const int idx_begin = segment_sizes_offsets[id];
  const int idx_end = (id == num_of_segments-1)?num_of_partial_segments:segment_sizes_offsets[id+1];
  acc_type<scalar_t, true> weight = 0;
  for (int idx=idx_begin; idx < idx_end; ++idx) {
    weight += grad_weight_per_segment[idx*stride + startFeature];
  }
  int64_t target_row = input[segment_offsets[id]];
  if (target_row != padding_idx) {
    gradWeight[target_row * stride + startFeature] = weight;
  }
}

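// Single-thread kernel that derives the total number of partial segments on
// the device: the last exclusive-sum offset plus the last per-segment count.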
template<typename index_t>
__global__ void compute_num_of_partial_segments(const index_t *partials_per_segment, const index_t *partials_per_segment_offset, const int64_t *num_of_segments_ptr, int64_t *output) {
  int64_t num_of_segments = *num_of_segments_ptr;
  *output = partials_per_segment[num_of_segments-1] +
            partials_per_segment_offset[num_of_segments-1];
}

#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
__global__ void write_num_of_segments_for_legacy_thrust_path(int64_t *num_of_segments_ptr, int64_t num_of_segments) {
  *num_of_segments_ptr = num_of_segments;
}
#endif

} // anon namespace

#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
template<typename index_t>
int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets);
#endif

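// Host entry point. Rough outline:
//   1. Find the unique values in `sorted_indices` and their start offsets
//      ("segments"); each segment corresponds to one row of `grad_weight`.
//   2. Split every segment into partial segments of at most NROWS_PER_THREAD rows.
//   3. compute_grad_weight(_bags) reduces each partial segment per feature column.
//   4. sum_and_scatter adds up the partial results of each segment and writes the
//      sum into the corresponding row of `grad_weight`.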
Tensor embedding_backward_cuda_kernel(
        const Tensor &grad,
        const Tensor &orig_indices,
        const Tensor &sorted_indices,
        const Tensor &count,
        int64_t num_weights,
        int padding_idx,
        bool mode_mean,
        const Tensor &offset2bag,
        const Tensor &bag_size,
        const Tensor &per_sample_weights) {

  auto stream = at::cuda::getCurrentCUDAStream();
  const ptrdiff_t numel = sorted_indices.numel();

  auto grad_weight = at::zeros({num_weights, grad.size(-1)}, grad.options());
  const int64_t stride = grad_weight.stride(0);

  // Compute the number of segments and their start positions so that we do not
  // have to spawn a warp per index. In this context, a segment is a run of rows
  // that should be summed up.
  // Unit: index in `sorted_indices` and `orig_indices`
  auto segment_offsets = at::empty({numel}, orig_indices.options());
  auto num_of_segments_tensor = at::empty({}, grad.options().dtype(kLong));
  int64_t *num_of_segments_ptr = num_of_segments_tensor.mutable_data_ptr<int64_t>();
#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
  AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
    int64_t num_of_segments = embedding_backward_cuda_kernel_unique_by_key<index_t>(sorted_indices, segment_offsets);
    write_num_of_segments_for_legacy_thrust_path<<<1, 1, 0, c10::cuda::getCurrentCUDAStream()>>>(num_of_segments_ptr, num_of_segments);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
#else
  AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
    cuda::cub::unique_by_key(
      sorted_indices.const_data_ptr<index_t>(), thrust::make_counting_iterator(0),
      segment_offsets.mutable_data_ptr<index_t>(),
      num_of_segments_ptr, sorted_indices.numel());
  });
#endif

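  // The number of distinct indices (segments) cannot exceed either the number
  // of index entries or the vocabulary size, so `max_segments` serves as a
  // host-side upper bound for allocations; the exact count stays on the device
  // in `num_of_segments_tensor`.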
  int64_t max_segments = std::min<int64_t>(numel, num_weights);

  AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
    // We split the segments up into pieces of at most `NROWS_PER_THREAD` rows.
    // Compute the number of partial-segments per segment (some partial-segments
    // may not have the full `NROWS_PER_THREAD` number of rows).
    auto partials_per_segment = at::empty({max_segments}, orig_indices.options());
    {
      krn_partials_per_segment<<<ceil_div(max_segments, 32), 32, 0, stream>>> (
              partials_per_segment.mutable_data_ptr<index_t>(),
              segment_offsets.const_data_ptr<index_t>(),
              num_of_segments_ptr,
              numel);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }

    // In order to compute `partial_segment_offset`, which is the start index
    // of each partial-segment in `sorted_indices`, we need to compute the
    // start position of each _segment_ in `partial_segment_offset`.
    // Unit: index in `partial_segment_offset`
    auto partials_per_segment_offset = at::empty({max_segments}, orig_indices.options());
    cuda::cub::exclusive_sum(
        partials_per_segment.const_data_ptr<index_t>(),
        partials_per_segment_offset.mutable_data_ptr<index_t>(),
        max_segments);

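    // Illustrative example: with segments of 25, 7, and 12 rows and
    // NROWS_PER_THREAD = 10, partials_per_segment = [3, 1, 2] and its exclusive
    // sum is [0, 3, 4]; the partial segments of segment `i` occupy positions
    // [offset[i], offset[i] + partials[i]) and there are 4 + 2 = 6 in total.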
    // The total number of partial-segments is the sum of `partials_per_segment`,
    // computed on the device as the last exclusive-sum offset plus the last count.
    auto num_of_partial_segments_tensor = at::empty({}, grad.options().dtype(kLong));
    int64_t *num_of_partial_segments_ptr = num_of_partial_segments_tensor.mutable_data_ptr<int64_t>();
    compute_num_of_partial_segments<index_t><<<1, 1, 0, c10::cuda::getCurrentCUDAStream()>>>(
      partials_per_segment.const_data_ptr<index_t>(),
      partials_per_segment_offset.const_data_ptr<index_t>(),
      num_of_segments_ptr, num_of_partial_segments_ptr);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

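    // Upper bound on the number of partial segments: each segment contributes at
    // most size / NROWS_PER_THREAD full partial segments plus one remainder, and
    // the segment sizes sum to at most `numel`.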
    auto max_partial_segment = numel / NROWS_PER_THREAD + max_segments;

    // Now we can compute the start position of each partial-segment
    // Unit: index in `sorted_indices` and `orig_indices`
    auto partial_segment_offset = at::empty({max_partial_segment}, orig_indices.options());
    {
      krn_partial_segment_offset<<<ceil_div(max_segments, 32), 32, 0, stream>>> (
              partial_segment_offset.mutable_data_ptr<index_t>(),
              partials_per_segment.const_data_ptr<index_t>(),
              partials_per_segment_offset.const_data_ptr<index_t>(),
              segment_offsets.const_data_ptr<index_t>(),
              num_of_segments_ptr);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }

    const int warp_size = at::cuda::warp_size();
    const int stride_warped = ceil_div(stride, warp_size)*warp_size;
    const int block = std::min(stride_warped, MAX_BLOCK_SIZE);
    const int grid = ceil_div(max_partial_segment*stride_warped, block);

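    // Launch configuration for the reduction kernels: the embedding dimension
    // `stride` is rounded up to a multiple of the warp size, and each group of
    // `stride_warped` consecutive threads handles one (partial) segment; thread
    // `gid` covers feature `gid % stride_warped` of segment `gid / stride_warped`,
    // with the excess threads returning early.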
    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
      grad.scalar_type(), "embedding_bag_backward_cuda_compute_grad_weight", [&] {
        // For numerical stability, the dtype of `grad_weight_per_segment`
        // should match `acc_type`
        using partial_weight_t = acc_type<scalar_t, true>;
        TensorOptions op;
        if(grad.dtype() == at::kHalf || grad.dtype() == at::kBFloat16) {
            op = grad.options().dtype(at::kFloat);
        } else {
            op = grad.options();
        }
        auto grad_weight_per_segment = at::empty({max_partial_segment, stride}, op);
        // Compute the sum of each partial-segment and handle bags
        if (offset2bag.defined()) {
          compute_grad_weight_bags<scalar_t><<<grid, block, 0, stream>>>(
            orig_indices.const_data_ptr<index_t>(),
            grad.const_data_ptr<scalar_t>(),
            offset2bag.const_data_ptr<index_t>(),
            count.defined() ? count.const_data_ptr<index_t>() : nullptr, numel, stride,
            mode_mean, bag_size.const_data_ptr<index_t>(),
            per_sample_weights.defined() ? per_sample_weights.const_data_ptr<scalar_t>() : nullptr,
            per_sample_weights.defined() ? per_sample_weights.stride(0) : 0,
            partial_segment_offset.const_data_ptr<index_t>(),
            num_of_partial_segments_ptr, grad_weight_per_segment.mutable_data_ptr<partial_weight_t>(),
            stride_warped);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        } else {
          compute_grad_weight<scalar_t><<<grid, block, 0, stream>>>(
            orig_indices.const_data_ptr<index_t>(),
            grad.const_data_ptr<scalar_t>(),
            count.defined() ? count.const_data_ptr<index_t>() : nullptr,
            numel, stride,
            partial_segment_offset.const_data_ptr<index_t>(),
            num_of_partial_segments_ptr,
            grad_weight_per_segment.mutable_data_ptr<partial_weight_t>(),
            stride_warped);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        }

        // Finally, we sum all the partial-sums and scatter them
        // into `grad_weight`.
        const int grid2 = ceil_div(max_segments*stride_warped, block);
        sum_and_scatter<scalar_t><<<grid2, block, 0, stream>>>(
          sorted_indices.const_data_ptr<index_t>(),
          grad_weight.mutable_data_ptr<scalar_t>(),
          stride,
          segment_offsets.const_data_ptr<index_t>(),
          num_of_segments_ptr, grad_weight_per_segment.const_data_ptr<partial_weight_t>(),
          partials_per_segment_offset.const_data_ptr<index_t>(),
          num_of_partial_segments_ptr,
          padding_idx,
          stride_warped);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
  });
  return grad_weight;
}

} // namespace at::native