#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/ceil_div.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/DeviceUtils.cuh>
#include <ATen/native/EmbeddingBag.h>
#include <ATen/TensorUtils.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/_embedding_bag_native.h>
#include <ATen/ops/_embedding_bag_forward_only_native.h>
#include <ATen/ops/_embedding_bag_dense_backward_native.h>
#include <ATen/ops/_embedding_bag_per_sample_weights_backward_native.h>
#endif

#include <ATen/cuda/cub.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/native/cuda/EmbeddingBackwardKernel.cuh>
#include <ATen/native/cuda/KernelUtils.cuh>
#include <ATen/native/cuda/block_reduce.cuh>

#include <c10/macros/Macros.h>

#if CUB_SUPPORTS_SCAN_BY_KEY()
#include <thrust/iterator/reverse_iterator.h>
#endif

namespace at::native {

#if !CUB_SUPPORTS_SCAN_BY_KEY()
template<typename index_t>
void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);
#endif

namespace {

std::pair<Tensor, Tensor> promoteIndicesAndOffsets(
    const Tensor& indices,
    const Tensor& offsets) {
  const auto commonType =
      promoteTypes(offsets.scalar_type(), indices.scalar_type());
  return {
      indices.scalar_type() == commonType ? indices
                                          : indices.toType(commonType),
      offsets.scalar_type() == commonType ? offsets
                                          : offsets.toType(commonType)};
}
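// For example, int32 indices paired with int64 offsets both come back as
// int64: promoteTypes() picks kLong, the kInt tensor is copied via toType(),
// and the tensor that already has the common dtype is returned unchanged.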

// This kernel assumes that all input tensors except `weight` and
// per_sample_weights are contiguous.
template <typename scalar_t, typename index_t>
__global__ void EmbeddingBag_updateOutputKernel_max(
    const index_t *input, const index_t *offsets, const scalar_t *weight, scalar_t *output,
    index_t *offset2bag, int64_t numIndices, int64_t numBags,
    int64_t featureSize, int64_t weight_stride0, int64_t weight_stride1,
    index_t *bag_size, index_t *max_indices,
    index_t padding_idx, int64_t numRows) {

  // the strategy here is that each bag x feature is handled by a single thread

  int64_t chunksPerBag = ceil_div(featureSize, (int64_t)blockDim.x);
  int64_t numChunks = numBags * chunksPerBag;
  int64_t chunkOffset = blockIdx.x * blockDim.y + threadIdx.y;
  int64_t chunkStride = gridDim.x * blockDim.y;
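  // For example (illustrative numbers): with blockDim = (32, 8) and
  // featureSize = 100, chunksPerBag = ceil(100 / 32) = 4, so each bag is
  // covered by 4 chunks of up to 32 feature columns each, and the
  // (blockIdx.x, threadIdx.y) slots walk all numBags * chunksPerBag chunks in
  // strides of gridDim.x * blockDim.y.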

  for (int64_t chunk = chunkOffset; chunk < numChunks; chunk += chunkStride) {
    int64_t featureDim = (chunk % chunksPerBag) * blockDim.x + threadIdx.x;
    if (featureDim < featureSize) {
      int64_t bag = chunk / chunksPerBag;
      const scalar_t *weightFeat = weight + featureDim * weight_stride1;
      int64_t begin = bag == 0 ? 0 : offsets[bag]; // forces first offset to be 0 instead of asserting on it
      int64_t end = (bag < numBags - 1) ? (offsets[bag + 1]) : numIndices;
      CUDA_KERNEL_ASSERT(end >= begin);
      scalar_t weightFeatMax = 0;
      int64_t bag_size_ = 0;
      int64_t maxWord = -1;
      for (int64_t emb = begin; emb < end; emb++) {
        bool pad = (input[emb] == padding_idx);
        CUDA_KERNEL_ASSERT(input[emb] < numRows);
        const int64_t weightRow = input[emb] * weight_stride0;
        scalar_t weightValue = weightFeat[weightRow];
        if (bag_size_ == 0 || weightValue > weightFeatMax) {
          weightFeatMax = pad ? weightFeatMax : weightValue;
          maxWord = pad ? maxWord : input[emb];
        }
        bag_size_ += pad ? 0 : 1;

        if (featureDim == 0) {
          offset2bag[emb] = bag;
        }
      }
      bag_size[bag] = bag_size_;
      max_indices[bag * featureSize + featureDim] = maxWord;
      output[bag * featureSize + featureDim] = weightFeatMax;
    }
  }
}

// This kernel assumes that all input tensors except `weight` and
// per_sample_weights are contiguous.
template <typename scalar_t, typename index_t>
__global__ void EmbeddingBag_updateOutputKernel_sum_mean(
    const index_t *input, const index_t *offsets, const scalar_t *weight, scalar_t *output,
    index_t *offset2bag, int64_t numIndices, int64_t numBags,
    int64_t featureSize, int64_t weight_stride0, int64_t weight_stride1,
    int mode, index_t *bag_size,
    const scalar_t* per_sample_weights, int64_t per_sample_weights_stride,
    index_t padding_idx, int64_t numRows) {

  // the strategy here is that each bag x feature is handled by a single thread

  using accscalar_t = acc_type<scalar_t, true>;
  int64_t chunksPerBag = ceil_div(featureSize, (int64_t)blockDim.x);
  int64_t numChunks = numBags * chunksPerBag;
  int64_t chunkOffset = blockIdx.x * blockDim.y + threadIdx.y;
  int64_t chunkStride = gridDim.x * blockDim.y;

  for (int64_t chunk = chunkOffset; chunk < numChunks; chunk += chunkStride) {
    int64_t featureDim = (chunk % chunksPerBag) * blockDim.x + threadIdx.x;
    if (featureDim < featureSize) {
      int64_t bag = chunk / chunksPerBag;
      const scalar_t *weightFeat = weight + featureDim * weight_stride1;
      int64_t begin = bag == 0 ? 0 : offsets[bag]; // forces first offset to be 0 instead of asserting on it
      int64_t end = (bag < numBags - 1) ? (offsets[bag + 1]) : numIndices;
      CUDA_KERNEL_ASSERT(end >= begin);
      accscalar_t weightFeatSum = 0;
      int64_t bag_size_ = 0;
      for (int64_t emb = begin; emb < end; emb++) {
        bool pad = (input[emb] == padding_idx);
        CUDA_KERNEL_ASSERT(input[emb] < numRows);
        const int64_t weightRow = input[emb] * weight_stride0;
        scalar_t weightValue = weightFeat[weightRow];
        weightValue = pad ? static_cast<scalar_t>(0) : weightValue;
        if (per_sample_weights) {
          accscalar_t scaleWeightBy = static_cast<accscalar_t>(
              per_sample_weights[emb * per_sample_weights_stride]);
          weightFeatSum += scaleWeightBy * static_cast<accscalar_t>(weightValue);
        } else {
          weightFeatSum += static_cast<accscalar_t>(weightValue);
        }
        bag_size_ += pad ? 0 : 1;

        if (featureDim == 0) {
          offset2bag[emb] = bag;
        }
      }
      if (mode == static_cast<int64_t>(EmbeddingBagMode::MEAN)) {
        if (bag_size_ != 0) {
          weightFeatSum = weightFeatSum / static_cast<accscalar_t>(bag_size_);
        }
      }
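      // For MEAN, dividing by bag_size_ (the count of non-padding elements)
      // rather than end - begin means padding_idx entries do not dilute the
      // average; an all-padding bag keeps a sum of 0.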
      bag_size[bag] = bag_size_;
      output[bag * featureSize + featureDim] = static_cast<scalar_t>(weightFeatSum);
    }
  }
}

Tensor embedding_bag_backward_cuda_sum_avg(
    const Tensor &grad,
    const Tensor &indices_,
    const Tensor &offset2bag,
    const Tensor &bag_size,
    int64_t num_weights,
    bool scale_grad_by_freq, int64_t mode,
    const Tensor& per_sample_weights,
    int64_t padding_idx) {
  auto indices = indices_.contiguous();

  ptrdiff_t num_indices = indices.numel();

  if (num_indices == 0) {
    // all empty bags
    return at::zeros({num_weights, grad.size(1)}, grad.options());
  }

  auto sorted_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto orig_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  Tensor count;

  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
    auto range = at::arange(num_indices, indices.options());
    // int64_t nbits = cuda::cub::get_num_bits(num_weights);
    cuda::cub::radix_sort_pairs(
        indices.const_data_ptr<index_t>(), sorted_indices.mutable_data_ptr<index_t>(),
        range.const_data_ptr<index_t>(), orig_indices.mutable_data_ptr<index_t>(),
        num_indices, false/*, 0, nbits*/);
  });
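  // The sort above, for example (illustrative values), maps indices =
  // [3, 1, 3, 2] to sorted_indices = [1, 2, 3, 3] with orig_indices =
  // [1, 3, 0, 2]; each sorted key remembers its original position so the
  // matching gradient rows can be gathered in embedding_backward_cuda_kernel.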

  if (scale_grad_by_freq) {
    count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
#if CUB_SUPPORTS_SCAN_BY_KEY()
    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
      cudaStream_t stream = at::cuda::getCurrentCUDAStream();

      // Compute an increasing sequence per unique item in sortedIndices:
      // sorted: 2 5 5 5 7 7 8 9 9
      //  count: 1 1 2 3 1 2 1 1 2
      auto sorted_data = sorted_indices.const_data_ptr<index_t>();
      auto count_data = count.mutable_data_ptr<index_t>();
      cuda::cub::inclusive_sum_by_key(
          sorted_data,
          at_cuda_detail::cub::ConstantInputIterator<index_t>(1),
          count_data,
          num_indices
      );

      // Take the maximum of each count per unique key in reverse:
      // sorted: 2 5 5 5 7 7 8 9 9
      //  count: 1 3 3 3 2 2 1 2 2
      cuda::cub::inclusive_scan_by_key(
          thrust::make_reverse_iterator(sorted_data + num_indices),
          thrust::make_reverse_iterator(count_data + num_indices),
          thrust::make_reverse_iterator(count_data + num_indices),
          at_cuda_detail::cub::Max(),
          num_indices
      );
    });
#else
    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_backward_cuda_sum_avg", [&] () {
      embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count);
    });
#endif
  }
  return embedding_backward_cuda_kernel(grad, orig_indices, sorted_indices,
      count, num_weights, padding_idx, mode == EmbeddingBagMode::MEAN, offset2bag,
      bag_size, per_sample_weights);
}

template <typename scalar_t, typename index_t>
__global__ void EmbeddingBag_accGradParametersKernel_max(
    const index_t *max_indices, const scalar_t *gradOutput,
    scalar_t *gradWeight, int64_t stride, int64_t numBags,
    index_t padding_idx, const index_t numel) {

  using accscalar_t = acc_type<scalar_t, true>;

  int64_t chunksPerBag = ceil_div(stride, (int64_t)blockDim.x);
  int64_t numChunks = numBags * chunksPerBag;
  int64_t chunkOffset = blockIdx.x * blockDim.y + threadIdx.y;
  int64_t chunkStride = gridDim.x * blockDim.y;

  for (int64_t chunk = chunkOffset; chunk < numChunks; chunk += chunkStride) {
    int64_t featureDim = (chunk % chunksPerBag) * blockDim.x + threadIdx.x;
    if (featureDim < stride) {
      int64_t bag = chunk / chunksPerBag;

      index_t word_idx = max_indices[bag * stride + featureDim];
      if (word_idx >= 0 && word_idx != padding_idx) {
        // If the bag is empty, the forward pass set max_indices for it to -1,
        // so it is skipped here (as are padding_idx entries).
        fastAtomicAdd(
            gradWeight, static_cast<index_t>(word_idx * stride + featureDim),
            numel, gradOutput[bag * stride + featureDim], true);
      }
    }
  }
}

Tensor embedding_bag_backward_cuda_max(const Tensor &grad,
                                       const Tensor &max_indices,
                                       int64_t num_weights,
                                       int64_t padding_idx) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("embedding_bag_backward_cuda_max");

  auto grad_weight = at::zeros({num_weights, grad.size(1)}, grad.options());

  int64_t stride = grad_weight.stride(0);

  int64_t numBags = grad.size(0);

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

#if defined(USE_ROCM)
  dim3 block = dim3(64, 4);
#else
  dim3 block = dim3(32, 8);
#endif
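  // Both configurations give 256 threads per block; blockDim.x appears to be
  // chosen to match the warp/wavefront width (64 on ROCm, 32 on CUDA) so each
  // chunk of feature columns maps onto one full warp.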
  int grid = 1024;

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad.scalar_type(), "embedding_bag_backward_cuda_max", [&] {
        AT_DISPATCH_INDEX_TYPES(max_indices.scalar_type(), "embedding_bag_backward_cuda_max", [&] () {
          EmbeddingBag_accGradParametersKernel_max<
              scalar_t, index_t><<<grid, block, 0, stream>>>(
              max_indices.const_data_ptr<index_t>(), grad.const_data_ptr<scalar_t>(),
              grad_weight.mutable_data_ptr<scalar_t>(), stride, numBags,
              padding_idx, grad_weight.numel());
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
      });

  return grad_weight;
}
}

// Assumes all input tensors are contiguous.
// See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
std::tuple<Tensor, Tensor, Tensor, Tensor>
_embedding_bag_forward_only_cuda(const Tensor &weight, const Tensor &indices,
                   const Tensor &offsets, const bool scale_grad_by_freq,
                   const int64_t mode, bool sparse, const std::optional<Tensor>& per_sample_weights_opt,
                   bool include_last_offset, int64_t padding_idx) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt);
  const Tensor& per_sample_weights = *per_sample_weights_maybe_owned;

  return _embedding_bag_cuda(
      weight,
      indices,
      offsets,
      scale_grad_by_freq,
      mode,
      sparse,
      per_sample_weights,
      include_last_offset,
      padding_idx);
}

// Assumes all input tensors are contiguous.
// See NOTE [ embedding_bag Native Functions ] in native_functions.yaml for details
std::tuple<Tensor, Tensor, Tensor, Tensor>
_embedding_bag_cuda(const Tensor &weight, const Tensor &indices_,
                   const Tensor &offsets_, const bool scale_grad_by_freq,
                   const int64_t mode, bool sparse, const std::optional<Tensor>& per_sample_weights_opt,
                   bool include_last_offset, int64_t padding_idx) {
  TORCH_CHECK(indices_.dim() == 1 || indices_.dim() == 2,
      "input has to be a 1D or 2D Tensor, but got Tensor of dimension ",
      indices_.dim());
  if (indices_.dim() == 1) {
    TORCH_CHECK(offsets_.dim() == 1,
        "offsets has to be a 1D Tensor, but got Tensor of dimension ",
        offsets_.dim());
  }
  TORCH_CHECK(weight.dim() == 2,
      "weight has to be a 2D Tensor, but got Tensor of dimension ",
      weight.dim());
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt);
  const Tensor& per_sample_weights = *per_sample_weights_maybe_owned;

  Tensor indices, offsets;
  std::tie(indices, offsets) = promoteIndicesAndOffsets(indices_, offsets_);
  auto indices_arg = TensorArg(indices, "indices", 1);
  checkScalarTypes("embedding_bag_cuda", indices_arg, {kLong, kInt});
  auto offsets_arg = TensorArg(offsets, "offsets", 1);
  checkScalarTypes("embedding_bag_cuda", offsets_arg, {kLong, kInt});
  checkSameType("embedding_bag_cuda", indices_arg, offsets_arg);
  auto weight_arg = TensorArg(weight, "weight", 1);
  checkSameGPU("embedding_bag_cuda", weight_arg, indices_arg);
  checkSameGPU("embedding_bag_cuda", weight_arg, offsets_arg);

  int64_t numIndices = indices.size(0);
  int64_t numBags = offsets.size(0);
  if (include_last_offset) {
    // See https://github.com/pytorch/pytorch/issues/29019
    // With include_last_offset, offsets carries one extra trailing element
    // equal to the size of indices. For CUDA devices we still use the legacy
    // implementation even when this flag is enabled, so drop that extra bag.
    TORCH_CHECK(
        numBags >= 1, "include_last_offset: numBags should be at least 1");
    numBags -= 1;
  }
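  // For example (illustrative values): offsets = [0, 2, 5] with 5 indices and
  // include_last_offset = true describes two bags, [0, 2) and [2, 5); after
  // numBags -= 1 the kernels' per-bag [begin, end) logic is unchanged because
  // the last bag already ends at numIndices.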
  int64_t featureSize = weight.size(1);

  auto bag_size = at::empty(offsets.sizes(), indices.options());
  auto offset2bag =
      at::empty({indices.size(0)}, indices.options()); // offset2bag = [0 0 0 0 0]

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  auto output = at::empty({numBags, featureSize}, weight.options());

  Tensor max_indices;

  if (mode == EmbeddingBagMode::MAX) {
    max_indices = at::empty({numBags, featureSize}, indices.options());
  } else {
    // No need to allocate if we aren't doing a backwards pass
    max_indices = at::empty({0}, indices.options());
  }

#if defined(USE_ROCM)
  dim3 block = dim3(64, 4);
#else
  dim3 block = dim3(32, 8);
#endif
  int grid = 1024;
  AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, weight.scalar_type(), "embedding_bag_cuda", [&] {
    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_bag_cuda", [&] () {
      if (mode == EmbeddingBagMode::MAX) {
        EmbeddingBag_updateOutputKernel_max<scalar_t, index_t><<<grid, block, 0, stream>>>(
            indices.const_data_ptr<index_t>(), offsets.const_data_ptr<index_t>(),
            weight.const_data_ptr<scalar_t>(), output.mutable_data_ptr<scalar_t>(),
            offset2bag.mutable_data_ptr<index_t>(), numIndices, numBags, featureSize,
            weight.stride(0), weight.stride(1), bag_size.mutable_data_ptr<index_t>(),
            max_indices.mutable_data_ptr<index_t>(),
            padding_idx, weight.size(0));
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      } else {
        EmbeddingBag_updateOutputKernel_sum_mean<scalar_t, index_t><<<grid, block, 0, stream>>>(
            indices.const_data_ptr<index_t>(), offsets.const_data_ptr<index_t>(),
            weight.const_data_ptr<scalar_t>(), output.mutable_data_ptr<scalar_t>(),
            offset2bag.mutable_data_ptr<index_t>(), numIndices, numBags, featureSize,
            weight.stride(0), weight.stride(1), mode, bag_size.mutable_data_ptr<index_t>(),
            per_sample_weights.defined() ? per_sample_weights.const_data_ptr<scalar_t>() : NULL,
            per_sample_weights.defined() ? per_sample_weights.stride(0) : 0,
            padding_idx, weight.size(0));
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      }
    });
  });

  return std::tuple<Tensor, Tensor, Tensor, Tensor>(output, offset2bag, bag_size, max_indices);
}

Tensor _embedding_bag_dense_backward_cuda(const Tensor &grad_, const Tensor &indices,
                                   const Tensor &offset2bag,
                                   const Tensor &bag_size_,
                                   const Tensor &max_indices,
                                   int64_t num_weights,
                                   bool scale_grad_by_freq, int64_t mode, const std::optional<Tensor>& per_sample_weights_opt,
                                   int64_t padding_idx) {
  // See [Note: hacky wrapper removal for optional tensor]
  c10::MaybeOwned<Tensor> per_sample_weights_maybe_owned = at::borrow_from_optional_tensor(per_sample_weights_opt);
  const Tensor& per_sample_weights = *per_sample_weights_maybe_owned;

  // indices, offsets and offset2bag are assumed to have the correct dtypes and
  // to be contiguous here, due to the checks in _embedding_bag_backward in
  // EmbeddingBag.cpp.
  // Also see NOTE [ embedding_bag Native Functions ] in native_functions.yaml
  // for more details.

  Tensor grad = grad_.contiguous();
  auto indices_arg = TensorArg(indices, "indices", 1);
  auto grad_arg = TensorArg(grad, "grad", 1);
  checkSameGPU("embedding_bag_cuda", grad_arg, indices_arg);

  switch (static_cast<EmbeddingBagMode>(mode)) {
    case EmbeddingBagMode::SUM:
    case EmbeddingBagMode::MEAN:
      if (mode == EmbeddingBagMode::MEAN)
        AT_ASSERT(!per_sample_weights.defined());
      return embedding_bag_backward_cuda_sum_avg(grad, indices, offset2bag,
          bag_size_, num_weights, scale_grad_by_freq, mode,
          per_sample_weights, padding_idx);

    case EmbeddingBagMode::MAX:
      AT_ASSERT(!per_sample_weights.defined());
      return embedding_bag_backward_cuda_max(grad, max_indices, num_weights,
          padding_idx);

    default:
      AT_ERROR(
          "Unknown mode for embedding_bag_backward_cuda ", mode);
  }
}

template <typename scalar_t, typename index_t>
__global__ static void _embedding_bag_per_sample_weights_backward_kernel(
    const scalar_t* grad, int64_t grad_stride0, int64_t grad_stride1,
    const scalar_t* weight, int64_t weight_stride0, int64_t weight_stride1,
    const index_t* indices,  // contiguous
    const index_t* offset2bag,  // contiguous
    int64_t num_samples,
    int64_t embedding_features,
    scalar_t* output,
    index_t padding_idx) {
  using accscalar_t = acc_type<scalar_t, true>;
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  const int warp = idx / C10_WARP_SIZE;
  const int thread_in_warp = idx % C10_WARP_SIZE;
  const int num_warps = blockDim.x * gridDim.x / C10_WARP_SIZE;

  // Each warp is responsible for the accumulation of one sample.
  // This involves doing one dot product between grad[bag_idx] and weight[embedding_idx].
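  // For example (illustrative numbers): with 512 threads per block and a warp
  // size of 32, each block contributes 16 warps, and warp w handles samples
  // w, w + num_warps, w + 2 * num_warps, ...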
  for (int sample_idx = warp; sample_idx < num_samples; sample_idx += num_warps) {
    accscalar_t result = 0.;
    const int bag_idx = (int)offset2bag[sample_idx];
    const int embedding_idx = (int)indices[sample_idx];
    if (embedding_idx != padding_idx) {
      for (int feature_idx = thread_in_warp; feature_idx < embedding_features;
          feature_idx += C10_WARP_SIZE) {
        result +=
            grad[grad_stride0 * bag_idx + grad_stride1 * feature_idx] *
            weight[weight_stride0 * embedding_idx + weight_stride1 * feature_idx];
      }
    }
    result = cuda_utils::WarpReduceSum<accscalar_t>(result);
    if (thread_in_warp == 0) {
      output[sample_idx] = result;
    }
  }
}

Tensor _embedding_bag_per_sample_weights_backward_cuda(
    const Tensor& grad,
    const Tensor& weight,  // NB: embedding table, not per_sample_weights
    const Tensor& indices_,
    const Tensor& offsets_,
    const Tensor& offset2bag,
    int64_t mode,
    int64_t padding_idx) {
  TORCH_CHECK(
      mode == EmbeddingBagMode::SUM,
      "embedding_bag_backward: per_sample_weights only supported for mode='sum'");

  AT_ASSERT(grad.dim() == 2);
  auto embedding_features = grad.size(1);

  Tensor indices, offsets;
  std::tie(indices, offsets) = promoteIndicesAndOffsets(indices_, offsets_);
  AT_ASSERT(indices.dim() == 1);
  auto num_samples = indices.size(0);

  AT_ASSERT(weight.dim() == 2);
  AT_ASSERT(weight.size(1) == embedding_features);

  const int threads_per_block = 512;
  const int warps_per_block = threads_per_block / at::cuda::warp_size();

  dim3 block(threads_per_block);
  dim3 grid((num_samples + warps_per_block - 1) / warps_per_block);
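  // One warp per sample: ceil(num_samples / warps_per_block) blocks of
  // threads_per_block threads, matching the warp-per-sample loop in the kernel.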

  auto output = at::empty({num_samples}, grad.options());

  // Return early when there are no samples in the batch. This avoids an
  // unnecessary kernel launch and also prevents cudaGetLastError() from
  // complaining about invalid launch arguments.
  if (num_samples == 0) {
    return output;
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      grad.scalar_type(), "_embedding_bag_per_sample_weights_backward_cuda", [&]() {
        AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "_embedding_bag_per_sample_weights_backward_cuda", [&]() {
          _embedding_bag_per_sample_weights_backward_kernel<scalar_t, index_t>
            <<<grid, block, 0, at::cuda::getCurrentCUDAStream()>>>(
              grad.const_data_ptr<scalar_t>(), grad.stride(0), grad.stride(1),
              weight.const_data_ptr<scalar_t>(), weight.stride(0), weight.stride(1),
              indices.const_data_ptr<index_t>(),
              offset2bag.const_data_ptr<index_t>(),
              num_samples,
              embedding_features,
              output.mutable_data_ptr<scalar_t>(),
              padding_idx);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
      }
  );
  return output;
}

} // namespace at::native