// aten/src/ATen/native/cuda/EmbeddingBag.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/native/EmbeddingBag.h>
#include <ATen/TensorUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/_embedding_bag_native.h>
#include <ATen/ops/_embedding_bag_forward_only_native.h>
#include <ATen/ops/_embedding_bag_dense_backward_native.h>
#include <ATen/ops/_embedding_bag_per_sample_weights_backward_native.h>
#endif

#include <ATen/cuda/cub.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/native/cuda/EmbeddingBackwardKernel.cuh>
#include <ATen/native/cuda/KernelUtils.cuh>
#include <ATen/native/cuda/block_reduce.cuh>

#include <c10/macros/Macros.h>

#if CUB_SUPPORTS_SCAN_BY_KEY()
#include <thrust/iterator/reverse_iterator.h>
#endif

namespace at::native {

#if !CUB_SUPPORTS_SCAN_BY_KEY()
template<typename index_t>
void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);
#endif

namespace {

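// Promote `indices` and `offsets` to a common integer dtype so a single
// index_t instantiation can service both tensors. For example, int32 indices
// paired with int64 offsets both come back as int64.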
std::pair<Tensor, Tensor> promoteIndicesAndOffsets(
    const Tensor& indices,
    const Tensor& offsets) {
  const auto commonType =
      promoteTypes(offsets.scalar_type(), indices.scalar_type());
  return {
      indices.scalar_type() == commonType ? indices
                                          : indices.toType(commonType),
      offsets.scalar_type() == commonType ? offsets
                                          : offsets.toType(commonType)};
}

// This kernel assumes that all input tensors except `weight` and
// per_sample_weights are contiguous.
template <typename scalar_t, typename index_t>
__global__ void EmbeddingBag_updateOutputKernel_max(
    const index_t *input, const index_t *offsets, const scalar_t *weight, scalar_t *output,
    index_t *offset2bag, int64_t numIndices, int64_t numBags,
    int64_t featureSize, int64_t weight_stride0, int64_t weight_stride1,
    index_t *bag_size, index_t *max_indices,
    index_t padding_idx, int64_t numRows) {

  // the strategy here is that each bag x feature is handled by a single thread

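  // Work decomposition: each "chunk" is one bag crossed with a window of up
  // to blockDim.x consecutive features; threadIdx.x walks the features of a
  // chunk while (blockIdx.x, threadIdx.y) grid-strides over the
  // numBags * chunksPerBag chunks. For example, with featureSize = 100 and
  // blockDim.x = 32, chunksPerBag = 4 and chunk 6 maps to bag 1,
  // features 64..95 (threads whose featureDim >= featureSize do nothing).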
  int64_t chunksPerBag = ceil_div(featureSize, (int64_t)blockDim.x);
  int64_t numChunks = numBags * chunksPerBag;
  int64_t chunkOffset = blockIdx.x * blockDim.y + threadIdx.y;
  int64_t chunkStride = gridDim.x * blockDim.y;

  for (int64_t chunk = chunkOffset; chunk < numChunks; chunk += chunkStride) {
    int64_t featureDim = (chunk % chunksPerBag) * blockDim.x + threadIdx.x;
    if (featureDim < featureSize) {
      int64_t bag = chunk / chunksPerBag;
      const scalar_t *weightFeat = weight + featureDim * weight_stride1;
      int64_t begin = bag == 0 ? 0 : offsets[bag]; // forces first offset to be 0 instead of asserting on it
      int64_t end = (bag < numBags - 1) ? (offsets[bag + 1]) : numIndices;
      CUDA_KERNEL_ASSERT(end >= begin);
      scalar_t weightFeatMax = 0;
      int64_t bag_size_ = 0;
      int64_t maxWord = -1;
      for (int64_t emb = begin; emb < end; emb++) {
        bool pad = (input[emb] == padding_idx);
        CUDA_KERNEL_ASSERT(input[emb] < numRows);
        const int64_t weightRow = input[emb] * weight_stride0;
        scalar_t weightValue = weightFeat[weightRow];
        if (bag_size_ == 0 || weightValue > weightFeatMax) {
          weightFeatMax = pad ? weightFeatMax : weightValue;
          maxWord = pad ? maxWord : input[emb];
        }
        bag_size_ += pad ? 0 : 1;

        if (featureDim == 0) {
          offset2bag[emb] = bag;
        }
      }
      bag_size[bag] = bag_size_;
      max_indices[bag * featureSize + featureDim] = maxWord;
      output[bag * featureSize + featureDim] = weightFeatMax;
    }
  }
}

// This kernel assumes that all input tensors except `weight` and
// per_sample_weights are contiguous.
template <typename scalar_t, typename index_t>
__global__ void EmbeddingBag_updateOutputKernel_sum_mean(
    const index_t *input, const index_t *offsets, const scalar_t *weight, scalar_t *output,
    index_t *offset2bag, int64_t numIndices, int64_t numBags,
    int64_t featureSize, int64_t weight_stride0, int64_t weight_stride1,
    int mode, index_t *bag_size,
    const scalar_t* per_sample_weights, int64_t per_sample_weights_stride,
    index_t padding_idx, int64_t numRows) {

  // the strategy here is that each bag x feature is handled by a single thread

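  // Accumulate in acc_type<scalar_t, true> (float for Half/BFloat16) to limit
  // rounding error; per-sample weights, when present, scale each row before
  // accumulation, and MEAN mode divides by the number of non-padding rows at
  // the end.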
  using accscalar_t = acc_type<scalar_t, true>;
  int64_t chunksPerBag = ceil_div(featureSize, (int64_t)blockDim.x);
  int64_t numChunks = numBags * chunksPerBag;
  int64_t chunkOffset = blockIdx.x * blockDim.y + threadIdx.y;
  int64_t chunkStride = gridDim.x * blockDim.y;

  for (int64_t chunk = chunkOffset; chunk < numChunks; chunk += chunkStride) {
    int64_t featureDim = (chunk % chunksPerBag) * blockDim.x + threadIdx.x;
    if (featureDim < featureSize) {
      int64_t bag = chunk / chunksPerBag;
      const scalar_t *weightFeat = weight + featureDim * weight_stride1;
      int64_t begin = bag == 0 ? 0 : offsets[bag]; // forces first offset to be 0 instead of asserting on it
      int64_t end = (bag < numBags - 1) ? (offsets[bag + 1]) : numIndices;
      CUDA_KERNEL_ASSERT(end >= begin);
      accscalar_t weightFeatSum = 0;
      int64_t bag_size_ = 0;
      for (int64_t emb = begin; emb < end; emb++) {
        bool pad = (input[emb] == padding_idx);
        CUDA_KERNEL_ASSERT(input[emb] < numRows);
        const int64_t weightRow = input[emb] * weight_stride0;
        scalar_t weightValue = weightFeat[weightRow];
        weightValue = pad ? static_cast<scalar_t>(0) : weightValue;
        if (per_sample_weights) {
          accscalar_t scaleWeightBy = static_cast<accscalar_t>(
              per_sample_weights[emb * per_sample_weights_stride]);
          weightFeatSum += scaleWeightBy * static_cast<accscalar_t>(weightValue);
        } else {
          weightFeatSum += static_cast<accscalar_t>(weightValue);
        }
        bag_size_ += pad ? 0 : 1;

        if (featureDim == 0) {
          offset2bag[emb] = bag;
        }
      }
      if (mode == static_cast<int64_t>(EmbeddingBagMode::MEAN)) {
        if (bag_size_ != 0) {
          weightFeatSum = weightFeatSum / static_cast<accscalar_t>(bag_size_);
        }
      }
      bag_size[bag] = bag_size_;
      output[bag * featureSize + featureDim] = static_cast<scalar_t>(weightFeatSum);
    }
  }
}

Tensor embedding_bag_backward_cuda_sum_avg(
                                   const Tensor &grad,
                                   const Tensor &indices_,
                                   const Tensor &offset2bag,
                                   const Tensor &bag_size,
                                   int64_t num_weights,
                                   bool scale_grad_by_freq, int64_t mode,
                                   const Tensor& per_sample_weights,
                                   int64_t padding_idx) {
  auto indices = indices_.contiguous();

  ptrdiff_t num_indices = indices.numel();

  if (num_indices == 0) {
    // all empty bags
    return at::zeros({num_weights, grad.size(1)}, grad.options());
  }

  auto sorted_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto orig_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  Tensor count;

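  // Sort the flattened indices so that duplicate rows end up adjacent;
  // orig_indices records the permutation back to the original positions,
  // which embedding_backward_cuda_kernel uses to gather the matching
  // gradient rows.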
  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
    auto range = at::arange(num_indices, indices.options());
    // int64_t nbits = cuda::cub::get_num_bits(num_weights);
    cuda::cub::radix_sort_pairs(
      indices.const_data_ptr<index_t>(), sorted_indices.mutable_data_ptr<index_t>(),
      range.const_data_ptr<index_t>(), orig_indices.mutable_data_ptr<index_t>(),
      num_indices, false/*, 0, nbits*/);
  });

  if (scale_grad_by_freq) {
    count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
#if CUB_SUPPORTS_SCAN_BY_KEY()
    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
      cudaStream_t stream = at::cuda::getCurrentCUDAStream();

      // Compute an increasing sequence per unique item in sortedIndices:
      // sorted: 2 5 5 5 7 7 8 9 9
      //  count: 1 1 2 3 1 2 1 1 2
      auto sorted_data = sorted_indices.const_data_ptr<index_t>();
      auto count_data = count.mutable_data_ptr<index_t>();
      cuda::cub::inclusive_sum_by_key(
        sorted_data,
        at_cuda_detail::cub::ConstantInputIterator<index_t>(1),
        count_data,
        num_indices
      );

      // Take the maximum of each count per unique key in reverse:
      // sorted: 2 5 5 5 7 7 8 9 9
      //  count: 1 3 3 3 2 2 1 2 2
      cuda::cub::inclusive_scan_by_key(
        thrust::make_reverse_iterator(sorted_data + num_indices),
        thrust::make_reverse_iterator(count_data + num_indices),
        thrust::make_reverse_iterator(count_data + num_indices),
        at_cuda_detail::cub::Max(),
        num_indices
      );
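
      // After the reverse max-scan, every element of `count` holds the total
      // number of occurrences of its index, which the backward kernel uses to
      // scale the gradient when scale_grad_by_freq is set.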
    });
#else
    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
      embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count);
    });
#endif
  }
  return embedding_backward_cuda_kernel(grad, orig_indices, sorted_indices,
      count, num_weights, padding_idx, mode == EmbeddingBagMode::MEAN, offset2bag,
      bag_size, per_sample_weights);
}

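// Backward for 'max' mode: for each (bag, feature) pair the gradient flows
// only to the embedding row that won the max in the forward pass (recorded
// in max_indices). fastAtomicAdd is required because several bags may have
// selected the same row.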
template <typename scalar_t, typename index_t>
__global__ void EmbeddingBag_accGradParametersKernel_max(
    const index_t *max_indices, const scalar_t *gradOutput,
    scalar_t *gradWeight, int64_t stride, int64_t numBags,
    index_t padding_idx, const index_t numel) {

  using accscalar_t = acc_type<scalar_t, true>;

  int64_t chunksPerBag = ceil_div(stride, (int64_t)blockDim.x);
  int64_t numChunks = numBags * chunksPerBag;
  int64_t chunkOffset = blockIdx.x * blockDim.y + threadIdx.y;
  int64_t chunkStride = gridDim.x * blockDim.y;

  for (int64_t chunk = chunkOffset; chunk < numChunks; chunk += chunkStride) {
    int64_t featureDim = (chunk % chunksPerBag) * blockDim.x + threadIdx.x;
    if (featureDim < stride) {
      int64_t bag = chunk / chunksPerBag;

      index_t word_idx = max_indices[bag * stride + featureDim];
      if (word_idx >= 0 && word_idx != padding_idx) {
        // If bag is empty, we have max_indices[idx] set to -1 in forward.
        fastAtomicAdd(
            gradWeight, static_cast<index_t>(word_idx * stride + featureDim),
            numel, gradOutput[bag * stride + featureDim], true);
      }
    }
  }
}

Tensor embedding_bag_backward_cuda_max(const Tensor &grad,
                                   const Tensor &max_indices,
                                   int64_t num_weights,
                                   int64_t padding_idx) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("embedding_bag_backward_cuda_max");

  auto grad_weight = at::zeros({num_weights, grad.size(1)}, grad.options());

  int64_t stride = grad_weight.stride(0);

  int64_t numBags = grad.size(0);

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

#if defined(USE_ROCM)
  dim3 block = dim3(64, 4);
#else
  dim3 block = dim3(32, 8);
#endif
  int grid = 1024;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad.scalar_type(), "embedding_bag_backward_cuda_max", [&] {
        AT_DISPATCH_INDEX_TYPES(max_indices.scalar_type(), "embedding_bag_backward_cuda_max", [&] () {
          EmbeddingBag_accGradParametersKernel_max<
              scalar_t, index_t><<<grid, block, 0, stream>>>(
              max_indices.const_data_ptr<index_t>(), grad.const_data_ptr<scalar_t>(),
              grad_weight.mutable_data_ptr<scalar_t>(), stride, numBags,
              padding_idx, grad_weight.numel());
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
  });

  return grad_weight;
}
}

// Assumes all input tensors are contiguous.
// See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
std::tuple<Tensor, Tensor, Tensor, Tensor>
_embedding_bag_forward_only_cuda(const Tensor &weight, const Tensor &indices,
                   const Tensor &offsets, const bool scale_grad_by_freq,
                   const int64_t mode, bool sparse, const std::optional<Tensor>& per_sample_weights_opt,
                   bool include_last_offset, int64_t padding_idx) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt);
  const Tensor& per_sample_weights = *per_sample_weights_maybe_owned;

  return _embedding_bag_cuda(
      weight,
      indices,
      offsets,
      scale_grad_by_freq,
      mode,
      sparse,
      per_sample_weights,
      include_last_offset,
      padding_idx);
}

// Assumes all input tensors are contiguous.
// See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
std::tuple<Tensor, Tensor, Tensor, Tensor>
_embedding_bag_cuda(const Tensor &weight, const Tensor &indices_,
                   const Tensor &offsets_, const bool scale_grad_by_freq,
                   const int64_t mode, bool sparse, const std::optional<Tensor>& per_sample_weights_opt,
                   bool include_last_offset, int64_t padding_idx) {
  TORCH_CHECK(indices_.dim() == 1 || indices_.dim() == 2,
      "input has to be a 1D or 2D Tensor, but got Tensor of dimension ",
      indices_.dim());
  if (indices_.dim() == 1) {
    TORCH_CHECK(offsets_.dim() == 1,
        "offsets has to be a 1D Tensor, but got Tensor of dimension ",
        offsets_.dim());
  }
  TORCH_CHECK(weight.dim() == 2,
      "weight has to be a 2D Tensor, but got Tensor of dimension ",
      weight.dim());
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt);
  const Tensor& per_sample_weights = *per_sample_weights_maybe_owned;

  Tensor indices, offsets;
  std::tie(indices, offsets) = promoteIndicesAndOffsets(indices_, offsets_);
  auto indices_arg = TensorArg(indices, "indices", 1);
  checkScalarTypes("embedding_bag_cuda", indices_arg, {kLong, kInt});
  auto offsets_arg = TensorArg(offsets, "offsets", 1);
  checkScalarTypes("embedding_bag_cuda", offsets_arg, {kLong, kInt});
  checkSameType("embedding_bag_cuda", indices_arg, offsets_arg);
  auto weight_arg = TensorArg(weight, "weight", 1);
  checkSameGPU("embedding_bag_cuda", weight_arg, indices_arg);
  checkSameGPU("embedding_bag_cuda", weight_arg, offsets_arg);

  int64_t numIndices = indices.size(0);
  int64_t numBags = offsets.size(0);
  if (include_last_offset) {
    // See https://github.com/pytorch/pytorch/issues/29019
    // With this flag, offsets carries one extra trailing element equal to the
    // size of indices. Currently, CUDA devices still use the legacy
    // implementation even when this flag is enabled.
    TORCH_CHECK(
        numBags >= 1, "include_last_offset: numBags should be at least 1");
    numBags -= 1;
  }
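  // Example: with include_last_offset=true, offsets = [0, 3, 5] and 5 indices
  // describe 2 bags, indices[0:3] and indices[3:5]; without the flag the same
  // offsets would describe 3 bags, the last one spanning indices[5:5] (empty).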
  int64_t featureSize = weight.size(1);

  auto bag_size = at::empty(offsets.sizes(), indices.options());
  auto offset2bag =
      at::empty({indices.size(0)}, indices.options()); // offset2bag = [0 0 0 0 0]

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  auto output = at::empty({numBags, featureSize}, weight.options());

  Tensor max_indices;

  if (mode == EmbeddingBagMode::MAX) {
    max_indices = at::empty({numBags, featureSize}, indices.options());
  } else {
    // No need to allocate if we aren't doing a backwards pass
    max_indices = at::empty({0}, indices.options());
  }

#if defined(USE_ROCM)
  dim3 block = dim3(64, 4);
#else
  dim3 block = dim3(32, 8);
#endif
  int grid = 1024;
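  // Launch configuration: 256 threads per block arranged as
  // (features, chunks) = (32, 8) on CUDA or (64, 4) on ROCm, with a fixed
  // 1024-block grid that strides over all numBags * chunksPerBag chunks.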
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, weight.scalar_type(), "embedding_bag_cuda", [&] {
    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_cuda", [&] () {
      if (mode == EmbeddingBagMode::MAX) {
        EmbeddingBag_updateOutputKernel_max<scalar_t, index_t><<<grid, block, 0, stream>>>(
            indices.const_data_ptr<index_t>(), offsets.const_data_ptr<index_t>(),
            weight.const_data_ptr<scalar_t>(), output.mutable_data_ptr<scalar_t>(),
            offset2bag.mutable_data_ptr<index_t>(), numIndices, numBags, featureSize,
            weight.stride(0), weight.stride(1), bag_size.mutable_data_ptr<index_t>(),
            max_indices.mutable_data_ptr<index_t>(),
            padding_idx, weight.size(0));
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else {
        EmbeddingBag_updateOutputKernel_sum_mean<scalar_t, index_t><<<grid, block, 0, stream>>>(
            indices.const_data_ptr<index_t>(), offsets.const_data_ptr<index_t>(),
            weight.const_data_ptr<scalar_t>(), output.mutable_data_ptr<scalar_t>(),
            offset2bag.mutable_data_ptr<index_t>(), numIndices, numBags, featureSize,
            weight.stride(0), weight.stride(1), mode, bag_size.mutable_data_ptr<index_t>(),
            per_sample_weights.defined() ? per_sample_weights.const_data_ptr<scalar_t>() : NULL,
            per_sample_weights.defined() ? per_sample_weights.stride(0) : 0,
            padding_idx, weight.size(0));
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    });
  });

  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, offset2bag, bag_size, max_indices);
}

Tensor _embedding_bag_dense_backward_cuda(const Tensor &grad_, const Tensor &indices,
                                   const Tensor &offset2bag,
                                   const Tensor &bag_size_,
                                   const Tensor &max_indices,
                                   int64_t num_weights,
                                   bool scale_grad_by_freq, int64_t mode, const std::optional<Tensor>& per_sample_weights_opt,
                                   int64_t padding_idx) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt);
  const Tensor& per_sample_weights = *per_sample_weights_maybe_owned;

  // indices, offsets and offset2bag are assumed to have the correct dtypes
  // and to be contiguous here due to the checks in _embedding_bag_backward
  // in EmbeddingBag.cpp.
  // Also see NOTE [ embedding_bag Native Functions ] in native_functions.yaml
  // for more details.

  Tensor grad = grad_.contiguous();
  auto indices_arg = TensorArg(indices, "indices", 1);
  auto grad_arg = TensorArg(grad, "grad", 1);
  checkSameGPU("embedding_bag_cuda", grad_arg, indices_arg);

  switch (static_cast<EmbeddingBagMode>(mode)) {
    case EmbeddingBagMode::SUM:
    case EmbeddingBagMode::MEAN:
      if (mode == EmbeddingBagMode::MEAN)
        AT_ASSERT(!per_sample_weights.defined());
      return embedding_bag_backward_cuda_sum_avg(grad, indices, offset2bag,
              bag_size_, num_weights, scale_grad_by_freq, mode,
              per_sample_weights, padding_idx);

    case EmbeddingBagMode::MAX:
      AT_ASSERT(!per_sample_weights.defined());
      return embedding_bag_backward_cuda_max(grad, max_indices, num_weights,
              padding_idx);

    default:
      AT_ERROR(
          "Unknown mode for embedding_bag_backward_cuda ", mode);
  }
}

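// Backward w.r.t. per_sample_weights (only defined for mode='sum'): each warp
// computes one sample's gradient as the dot product of grad[bag_idx] and
// weight[embedding_idx], reduced across the warp lanes.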
template <typename scalar_t, typename index_t>
__global__ static void _embedding_bag_per_sample_weights_backward_kernel(
    const scalar_t* grad, int64_t grad_stride0, int64_t grad_stride1,
    const scalar_t* weight, int64_t weight_stride0, int64_t weight_stride1,
    const index_t* indices,  // contiguous
    const index_t* offset2bag,  // contiguous
    int64_t num_samples,
    int64_t embedding_features,
    scalar_t* output,
    index_t padding_idx) {
  using accscalar_t = acc_type<scalar_t, true>;
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  const int warp = idx / C10_WARP_SIZE;
  const int thread_in_warp = idx % C10_WARP_SIZE;
  const int num_warps = blockDim.x * gridDim.x / C10_WARP_SIZE;

  // Each warp is responsible for the accumulation of one sample.
  // This involves doing one dot product between grad[bag_idx] and weight[embedding_idx].
  for (int sample_idx = warp; sample_idx < num_samples; sample_idx += num_warps) {
    accscalar_t result = 0.;
    const int bag_idx = (int)offset2bag[sample_idx];
    const int embedding_idx = (int)indices[sample_idx];
    if (embedding_idx != padding_idx) {
      for (int feature_idx = thread_in_warp; feature_idx < embedding_features;
          feature_idx += C10_WARP_SIZE) {
        result +=
            grad[grad_stride0 * bag_idx + grad_stride1 * feature_idx] *
            weight[weight_stride0 * embedding_idx + weight_stride1 * feature_idx];
      }
    }
    result = cuda_utils::WarpReduceSum<accscalar_t>(result);
    if (thread_in_warp == 0) {
      output[sample_idx] = result;
    }
  }
}

Tensor _embedding_bag_per_sample_weights_backward_cuda(
    const Tensor& grad,
    const Tensor& weight,  // NB: embedding table, not per_sample_weights
    const Tensor& indices_,
    const Tensor& offsets_,
    const Tensor& offset2bag,
    int64_t mode,
    int64_t padding_idx) {
  TORCH_CHECK(
      mode == EmbeddingBagMode::SUM,
      "embedding_bag_backward: per_sample_weights only supported for mode='sum'");

  AT_ASSERT(grad.dim() == 2);
  auto embedding_features = grad.size(1);

  Tensor indices, offsets;
  std::tie(indices, offsets) = promoteIndicesAndOffsets(indices_, offsets_);
  AT_ASSERT(indices.dim() == 1);
  auto num_samples = indices.size(0);

  AT_ASSERT(weight.dim() == 2);
  AT_ASSERT(weight.size(1) == embedding_features);

  const int threads_per_block = 512;
  const int warps_per_block = threads_per_block / at::cuda::warp_size();

  dim3 block(threads_per_block);
  dim3 grid((num_samples + warps_per_block - 1) / warps_per_block);
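  // One warp per sample: each block holds warps_per_block warps, so
  // ceil(num_samples / warps_per_block) blocks cover every sample.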

  auto output = at::empty({num_samples}, grad.options());

  // Early return when there are no samples in the batch. This saves an
  // unnecessary kernel launch and also prevents cudaGetLastError() from
  // complaining about invalid launch args.
  if (num_samples == 0) {
    return output;
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    grad.scalar_type(), "_embedding_bag_per_sample_weights_backward_cuda", [&]() {
      AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "_embedding_bag_per_sample_weights_backward_cuda", [&]() {
        _embedding_bag_per_sample_weights_backward_kernel<scalar_t, index_t>
          <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
            grad.const_data_ptr<scalar_t>(), grad.stride(0), grad.stride(1),
            weight.const_data_ptr<scalar_t>(), weight.stride(0), weight.stride(1),
            indices.const_data_ptr<index_t>(),
            offset2bag.const_data_ptr<index_t>(),
            num_samples,
            embedding_features,
            output.mutable_data_ptr<scalar_t>(),
            padding_idx);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
    }
  );
  return output;
}

} // namespace at::native