#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/cuda/EmbeddingBackwardKernel.cuh>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/cub.cuh>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/native/cuda/SortingCommon.cuh>

#include <c10/macros/Macros.h>

#if CUB_SUPPORTS_UNIQUE_BY_KEY()
#include <thrust/iterator/counting_iterator.h>
#endif

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/zeros.h>
#endif

namespace at::native {

namespace {

/* This code computes the sum of the weights in two steps:
  1) Each GPU warp sums `NROWS_PER_THREAD` rows given by `indices`
  2) The partial sums from 1) are summed and scattered into `grad_weight`

  Note that `NROWS_PER_THREAD` impacts the Achieved Occupancy of the
  kernel execution. If it is high, the size of the thread blocks will be
  too small to achieve good occupancy. Similarly, a very low value will
  make the size of the thread blocks in the final sum in step 2) too small.
*/
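// Illustrative example (made-up values): with NROWS_PER_THREAD = 10 and
// sorted_indices = [2, 2, 2, 5, 5, ..., 5] (3 twos followed by 25 fives),
// there are two segments with segment_offsets = [0, 3]. The first segment
// yields 1 partial segment and the second ceil(25 / 10) = 3, so step 1)
// produces 4 partial row-sums, which step 2) reduces and scatters into
// rows 2 and 5 of `grad_weight`.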
constexpr int NROWS_PER_THREAD = 10;

// Fast ceil division (no overflow checking)
__host__ __device__ __forceinline__
int64_t ceil_div(int64_t x, int64_t y) {
  return (x + y - 1) / y;
}

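// One thread per segment: store in `ret[id]` how many NROWS_PER_THREAD-row
// partial segments the segment splits into (the last one may hold fewer rows).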
template <typename index_t>
__global__
void krn_partials_per_segment(index_t *ret, const index_t *segment_offsets,
                              const int64_t *num_of_segments_ptr, int64_t numel) {
  int64_t num_of_segments = *num_of_segments_ptr;
  const int id = blockIdx.x * blockDim.x + threadIdx.x;
  if (id < num_of_segments) {
    const int64_t idx_start = segment_offsets[id];
    const int64_t idx_end = (id == num_of_segments - 1) ? numel : segment_offsets[id + 1];
    const int64_t size = idx_end - idx_start;
    ret[id] = ceil_div(size, NROWS_PER_THREAD);
  }
}

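// One thread per segment: write the start offset (an index into `sorted_indices`)
// of each of the segment's partial segments into `ret`, at consecutive positions
// beginning at `partials_per_segment_offset[id]`.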
template <typename index_t>
__global__
void krn_partial_segment_offset(
    index_t *ret,
    const index_t *partials_per_segment,
    const index_t *partials_per_segment_offset,
    const index_t *segment_offsets,
    const int64_t *num_of_segments_ptr) {
  int64_t num_of_segments = *num_of_segments_ptr;
  const int id = blockIdx.x * blockDim.x + threadIdx.x;
  if (id < num_of_segments) {
    index_t idx = partials_per_segment_offset[id];
    const index_t num_partials = partials_per_segment[id];
    const index_t segment_offset = segment_offsets[id];
    for (int64_t i = 0; i < num_partials; ++i) {
      ret[idx++] = segment_offset + i * NROWS_PER_THREAD;
    }
  }
}

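// EmbeddingBag path. One thread per (segment, feature) pair: `gid / stride_warped`
// selects the segment (the partial segments when launched below) and
// `gid % stride_warped` the feature column; threads whose column falls in the
// warp-size padding return immediately. Each thread accumulates, in `acc_type`
// precision, the gradient rows of its segment (every index occurrence is mapped
// to its bag via `offset2bag`), scaled by 1 / count[idx] when `count` is given
// and by `per_sample_weights` when present, and divided by the bag size in
// mean mode, then writes the partial sum to `grad_weight_per_segment`.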
template <typename scalar_t, typename index_t>
__global__ void compute_grad_weight_bags(
    const index_t *indices, const scalar_t *gradOutput,
    const index_t *offset2bag, const index_t *count, ptrdiff_t numel,
    int64_t stride, int mode_mean, const index_t *bag_size,
    const scalar_t* per_sample_weights, int64_t per_sample_weights_stride,
    const index_t* segment_offsets, const int64_t *num_of_segments_ptr,
    acc_type<scalar_t, true> *grad_weight_per_segment,
    const int64_t stride_warped) {

  int64_t num_of_segments = *num_of_segments_ptr;
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
  const int id = gid / stride_warped;
  const int startFeature = gid % stride_warped;
  if (startFeature >= stride) {
    return;
  }
  if (id >= num_of_segments) {
    return;
  }
  const int idx_begin = segment_offsets[id];
  const int idx_end = (id == num_of_segments - 1) ? numel : segment_offsets[id + 1];

  acc_type<scalar_t, true> weight = 0;
  for (int idx = idx_begin; idx < idx_end; ++idx) {
    const int origRow = indices[idx];
    const int seq_number = offset2bag[origRow];
    const int gradOutputRow = seq_number * stride;

    acc_type<scalar_t, true> scale = count ? 1.0 / count[idx] : 1.0;
    if (per_sample_weights) {
      scale *= per_sample_weights[origRow * per_sample_weights_stride];
    }

    acc_type<scalar_t, true> gradient = gradOutput[gradOutputRow + startFeature];
    if (mode_mean) {
      gradient /= bag_size[seq_number];
    }
    weight += gradient * scale;
  }
  grad_weight_per_segment[id * stride + startFeature] = weight;
}

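// Plain embedding path (no bags). Same thread mapping as above: each thread
// accumulates, in `acc_type` precision, the rows of `gradOutput` selected by
// `indices` within its segment, optionally scaled by 1 / count[idx], and
// stores the partial sum in `grad_weight_per_segment`.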
template <typename scalar_t, typename index_t>
__global__ void compute_grad_weight(
    const index_t *indices,
    const scalar_t *gradOutput,
    const index_t *count,
    ptrdiff_t numel,
    int64_t stride,
    const index_t* segment_offsets,
    const int64_t *num_of_segments_ptr,
    acc_type<scalar_t, true> *grad_weight_per_segment,
    const int64_t stride_warped) {

  int64_t num_of_segments = *num_of_segments_ptr;
  using accscalar_t = acc_type<scalar_t, true>;
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
  const int id = gid / stride_warped;
  const int startFeature = gid % stride_warped;
  if (startFeature >= stride) {
    return;
  }
  if (id >= num_of_segments) {
    return;
  }
  const int idx_begin = segment_offsets[id];
  const int idx_end = (id == num_of_segments - 1) ? numel : segment_offsets[id + 1];

  accscalar_t weight = 0;
  for (int idx = idx_begin; idx < idx_end; ++idx) {
    const index_t target_row = indices[idx];
    const accscalar_t scale = count ? (accscalar_t)1.0 / count[idx] : 1.0;
    weight += gradOutput[target_row * stride + startFeature] * scale;
  }
  grad_weight_per_segment[id * stride + startFeature] = weight;
}

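// Final reduction: one thread per (segment, feature) pair sums the partial
// sums of its segment, i.e. `grad_weight_per_segment` entries from
// `segment_sizes_offsets[id]` up to the next segment's offset, and scatters
// the result into the `gradWeight` row addressed by the segment's index,
// skipping rows equal to `padding_idx`.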
// This kernel assumes that all input tensors are contiguous.
template <typename scalar_t, typename index_t>
__global__ void sum_and_scatter(
    const index_t *input, scalar_t *gradWeight, int64_t stride,
    const index_t* segment_offsets, const int64_t *num_of_segments_ptr,
    const acc_type<scalar_t, true> *grad_weight_per_segment,
    const index_t *segment_sizes_offsets, const int64_t *num_of_partial_segments_ptr,
    const int64_t padding_idx,
    const int64_t stride_warped) {

  int64_t num_of_segments = *num_of_segments_ptr;
  int64_t num_of_partial_segments = *num_of_partial_segments_ptr;
  const int gid = blockIdx.x * blockDim.x + threadIdx.x;
  const int id = gid / stride_warped;
  const int startFeature = gid % stride_warped;
  if (startFeature >= stride) {
    return;
  }
  if (id >= num_of_segments) {
    return;
  }

  const int idx_begin = segment_sizes_offsets[id];
  const int idx_end = (id == num_of_segments - 1) ? num_of_partial_segments : segment_sizes_offsets[id + 1];
  acc_type<scalar_t, true> weight = 0;
  for (int idx = idx_begin; idx < idx_end; ++idx) {
    weight += grad_weight_per_segment[idx * stride + startFeature];
  }
  int64_t target_row = input[segment_offsets[id]];
  if (target_row != padding_idx) {
    gradWeight[target_row * stride + startFeature] = weight;
  }
}

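// Single-thread kernel: since `partials_per_segment_offset` is an exclusive
// prefix sum, the total number of partial segments equals the last offset plus
// the last per-segment count.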
template<typename index_t>
__global__ void compute_num_of_partial_segments(
    const index_t *partials_per_segment,
    const index_t *partials_per_segment_offset,
    const int64_t *num_of_segments_ptr,
    int64_t *output) {
  int64_t num_of_segments = *num_of_segments_ptr;
  *output = partials_per_segment[num_of_segments - 1] +
            partials_per_segment_offset[num_of_segments - 1];
}

#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
__global__ void write_num_of_segments_for_legacy_thrust_path(int64_t *num_of_segments_ptr, int64_t num_of_segments) {
  *num_of_segments_ptr = num_of_segments;
}
#endif

} // anon namespace

#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
template<typename index_t>
int64_t embedding_backward_cuda_kernel_unique_by_key(const Tensor &sorted_indices, Tensor &segment_offsets);
#endif

Tensor embedding_backward_cuda_kernel(
    const Tensor &grad,
    const Tensor &orig_indices,
    const Tensor &sorted_indices,
    const Tensor &count,
    int64_t num_weights,
    int padding_idx,
    bool mode_mean,
    const Tensor &offset2bag,
    const Tensor &bag_size,
    const Tensor &per_sample_weights) {

  auto stream = at::cuda::getCurrentCUDAStream();
  const ptrdiff_t numel = sorted_indices.numel();

  auto grad_weight = at::zeros({num_weights, grad.size(-1)}, grad.options());
  const int64_t stride = grad_weight.stride(0);

  // Compute the number of segments and their start positions so that we do not
  // have to spawn a warp per index. In this context, a segment is a run of equal
  // values in `sorted_indices` whose rows are summed together.
  // Unit: index in `sorted_indices` and `orig_indices`
  auto segment_offsets = at::empty({numel}, orig_indices.options());
  auto num_of_segments_tensor = at::empty({}, grad.options().dtype(kLong));
  int64_t *num_of_segments_ptr = num_of_segments_tensor.mutable_data_ptr<int64_t>();
#if !CUB_SUPPORTS_UNIQUE_BY_KEY()
  AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
    int64_t num_of_segments = embedding_backward_cuda_kernel_unique_by_key<index_t>(sorted_indices, segment_offsets);
    write_num_of_segments_for_legacy_thrust_path<<<1, 1, 0, c10::cuda::getCurrentCUDAStream()>>>(num_of_segments_ptr, num_of_segments);
    C10_CUDA_KERNEL_LAUNCH_CHECK();
  });
#else
  AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
    cuda::cub::unique_by_key(
      sorted_indices.const_data_ptr<index_t>(), thrust::make_counting_iterator(0),
      segment_offsets.mutable_data_ptr<index_t>(),
      num_of_segments_ptr, sorted_indices.numel());
  });
#endif

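  // There cannot be more segments (runs of equal values in `sorted_indices`)
  // than there are indices, nor than there are rows in the weight matrix.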
  int64_t max_segments = std::min<int64_t>(numel, num_weights);

  AT_DISPATCH_INDEX_TYPES(orig_indices.scalar_type(), "embedding_backward_cuda_kernel", [&] () {
    // We split the segments up into sizes of `NROWS_PER_THREAD`.
    // Compute the number of partial segments per segment (some partial segments
    // may not hold the full `NROWS_PER_THREAD` number of rows).
    auto partials_per_segment = at::empty({max_segments}, orig_indices.options());
    {
      krn_partials_per_segment<<<ceil_div(max_segments, 32), 32, 0, stream>>> (
          partials_per_segment.mutable_data_ptr<index_t>(),
          segment_offsets.const_data_ptr<index_t>(),
          num_of_segments_ptr,
          numel);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }

    // In order to compute `partial_segment_offset`, which is the start index
    // of each partial segment in `sorted_indices`, we need to compute the
    // start position of each _segment_ in `partial_segment_offset`.
    // Unit: index in `partial_segment_offset`
    auto partials_per_segment_offset = at::empty({max_segments}, orig_indices.options());
    cuda::cub::exclusive_sum(
      partials_per_segment.const_data_ptr<index_t>(),
      partials_per_segment_offset.mutable_data_ptr<index_t>(),
      max_segments);

    // The total number of partial segments is the sum of `partials_per_segment`,
    // i.e. the last entry of the exclusive sum plus the last per-segment count.
    auto num_of_partial_segments_tensor = at::empty({}, grad.options().dtype(kLong));
    int64_t *num_of_partial_segments_ptr = num_of_partial_segments_tensor.mutable_data_ptr<int64_t>();
    compute_num_of_partial_segments<index_t><<<1, 1, 0, c10::cuda::getCurrentCUDAStream()>>>(
        partials_per_segment.const_data_ptr<index_t>(),
        partials_per_segment_offset.const_data_ptr<index_t>(),
        num_of_segments_ptr, num_of_partial_segments_ptr);
    C10_CUDA_KERNEL_LAUNCH_CHECK();

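    // Upper bound on the number of partial segments: each segment yields
    // ceil(size / NROWS_PER_THREAD) <= size / NROWS_PER_THREAD + 1 of them,
    // so the total is at most numel / NROWS_PER_THREAD + max_segments.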
    auto max_partial_segment = numel / NROWS_PER_THREAD + max_segments;

    // Now we can compute the start position of each partial segment
    // Unit: index in `sorted_indices` and `orig_indices`
    auto partial_segment_offset = at::empty({max_partial_segment}, orig_indices.options());
    {
      krn_partial_segment_offset<<<ceil_div(max_segments, 32), 32, 0, stream>>> (
          partial_segment_offset.mutable_data_ptr<index_t>(),
          partials_per_segment.const_data_ptr<index_t>(),
          partials_per_segment_offset.const_data_ptr<index_t>(),
          segment_offsets.const_data_ptr<index_t>(),
          num_of_segments_ptr);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    }

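    // `stride_warped` rounds the feature stride up to a multiple of the warp
    // size so that all threads of a warp map to the same segment id; threads
    // landing in the padding (startFeature >= stride) exit early in the kernels.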
    const int warp_size = at::cuda::warp_size();
    const int stride_warped = ceil_div(stride, warp_size) * warp_size;
    const int block = std::min(stride_warped, MAX_BLOCK_SIZE);
    const int grid = ceil_div(max_partial_segment * stride_warped, block);

    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16,
      grad.scalar_type(), "embedding_bag_backward_cuda_compute_grad_weight", [&] {
        // For numerical stability, the dtype of `grad_weight_per_segment`
        // should match `acc_type`
        using partial_weight_t = acc_type<scalar_t, true>;
        TensorOptions op;
        if (grad.dtype() == at::kHalf || grad.dtype() == at::kBFloat16) {
          op = grad.options().dtype(at::kFloat);
        } else {
          op = grad.options();
        }
        auto grad_weight_per_segment = at::empty({max_partial_segment, stride}, op);
        // Compute the sum of each partial segment and handle bags
        if (offset2bag.defined()) {
          compute_grad_weight_bags<scalar_t><<<grid, block, 0, stream>>>(
            orig_indices.const_data_ptr<index_t>(),
            grad.const_data_ptr<scalar_t>(),
            offset2bag.const_data_ptr<index_t>(),
            count.defined() ? count.const_data_ptr<index_t>() : nullptr, numel, stride,
            mode_mean, bag_size.const_data_ptr<index_t>(),
            per_sample_weights.defined() ? per_sample_weights.const_data_ptr<scalar_t>() : nullptr,
            per_sample_weights.defined() ? per_sample_weights.stride(0) : 0,
            partial_segment_offset.const_data_ptr<index_t>(),
            num_of_partial_segments_ptr, grad_weight_per_segment.mutable_data_ptr<partial_weight_t>(),
            stride_warped);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        } else {
          compute_grad_weight<scalar_t><<<grid, block, 0, stream>>>(
            orig_indices.const_data_ptr<index_t>(),
            grad.const_data_ptr<scalar_t>(),
            count.defined() ? count.const_data_ptr<index_t>() : nullptr,
            numel, stride,
            partial_segment_offset.const_data_ptr<index_t>(),
            num_of_partial_segments_ptr,
            grad_weight_per_segment.mutable_data_ptr<partial_weight_t>(),
            stride_warped);
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        }

        // Finally, we sum all the partial sums and scatter them
        // into `grad_weight`.
        const int grid2 = ceil_div(max_segments * stride_warped, block);
        sum_and_scatter<scalar_t><<<grid2, block, 0, stream>>>(
          sorted_indices.const_data_ptr<index_t>(),
          grad_weight.mutable_data_ptr<scalar_t>(),
          stride,
          segment_offsets.const_data_ptr<index_t>(),
          num_of_segments_ptr, grad_weight_per_segment.const_data_ptr<partial_weight_t>(),
          partials_per_segment_offset.const_data_ptr<index_t>(),
          num_of_partial_segments_ptr,
          padding_idx,
          stride_warped);
        C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
  });
  return grad_weight;
}

} // namespace at::native