#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/ceil_div.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Exception.h>
#include <c10/macros/Macros.h>

#include <ATen/cuda/cub.cuh>

#include <ATen/native/cuda/EmbeddingBackwardKernel.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/native/cuda/block_reduce.cuh>
#include <ATen/native/cuda/thread_constants.h>

#if CUB_SUPPORTS_SCAN_BY_KEY()
#include <thrust/iterator/reverse_iterator.h>
#endif

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/embedding_dense_backward_native.h>
#include <ATen/ops/embedding_renorm_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros.h>
#endif

namespace at::native {

namespace {

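// Block-y dimension for the fast-path backward kernel below. The launch uses
// (warp_size, BLOCKDIMY) threads per block, so the smaller value on ROCm (where the
// wavefront is typically 64 wide) keeps the block at 1024 threads on both platforms.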
#if defined(USE_ROCM)
static const int BLOCKDIMY = 16;
#else
static const int BLOCKDIMY = 32;
#endif

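// Deterministic fast-path backward: the whole block caches a batch of indices in
// shared memory, each warp stages one incoming gradient row in its shared-memory
// segment, and warps whose destination rows collide elect a leader (via warp ballot)
// that serializes the accumulation into grad_weight.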
template
  <typename scalar_t,
   typename accscalar_t,
   typename index_t>
__global__ void embedding_backward_feature_kernel
  (const index_t* indices,
   const scalar_t* __restrict__ grad,
   scalar_t* __restrict__ grad_weight,
   int n, // OK to pass as int, we don't expect 2 billion+ samples in one shot
   int64_t stride,
   int padding_idx)
{
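  // Dynamic shared memory layout (matching the launch below): blockDim.y warp-wide
  // accscalar_t accumulation buffers, followed by blockDim.x*blockDim.y cached indices.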
  extern __shared__ char buf[];
  accscalar_t* smem = (accscalar_t*)buf;
  accscalar_t* my_s = smem + C10_WARP_SIZE*threadIdx.y;
  int* indices_batch = (int*)(buf + sizeof(accscalar_t)*C10_WARP_SIZE*blockDim.y);

  const int s = (int)stride; // OK to make int, we don't expect 2 billion+ embedding row size

  const int f = threadIdx.x + blockIdx.x*blockDim.x; // feature_dim

  for(int batch_start = 0; batch_start < n; batch_start += blockDim.x*blockDim.y)
  {
    // Entire block cooperates to load a batch of 1024 indices to process
    int tid = threadIdx.x + threadIdx.y*blockDim.x;
    if(batch_start + tid < n)
      indices_batch[tid] = (int)indices[batch_start + tid];

    int batch_end = batch_start + blockDim.x*blockDim.y < n ?
                    batch_start + blockDim.x*blockDim.y : n;

    // Loop over the batch of <= 1024 loaded indices in chunks of blockDim.y (= BLOCKDIMY)
    for(int chunk_start = batch_start; chunk_start < batch_end; chunk_start += blockDim.y)
    {
      // This does double duty: it makes sure indices_batch is ready, and it makes sure match-group
      // leaders are done with their accumulates before other warps start loading again.
      __syncthreads();

      int n_this_chunk = (batch_end - chunk_start) < blockDim.y ?
                         (batch_end - chunk_start) : blockDim.y;

      int src_row = chunk_start + threadIdx.y;
      int dst_row = indices_batch[src_row - batch_start]; // This warp's target row in grad_weight

      // All warps load their smem segments with incoming grad data
      if(src_row < n && f < s && dst_row != padding_idx)
        my_s[threadIdx.x] = static_cast<accscalar_t>(grad[src_row*stride + f]);

      __syncthreads();

      // To ensure determinism, we can't just have each warp add its grad data to its dst_row.
      // We need to check if any other warps pulled grad data targeting dst_row.
      // If so, we elect the first warp in each matching group as the leader.
      // Each leader warp serializes the accumulates targeting dst_row in shared memory,
      // then finishes by adding the accumulated buffer to dst_row in grad_weight.
      if(dst_row != padding_idx && src_row < n) // Per-warp exit condition, safe with ballot_sync
      {
        int match_found_this_thread = 0;
        if(threadIdx.x < n_this_chunk)
          match_found_this_thread = (dst_row == indices_batch[chunk_start - batch_start + threadIdx.x]);
#if defined(USE_ROCM)
        unsigned long long int matchmask = WARP_BALLOT(match_found_this_thread);
        int first_remaining_peer = __ffsll(matchmask) - 1;
#else
        unsigned int matchmask = WARP_BALLOT(match_found_this_thread);
        int first_remaining_peer = __ffs(matchmask) - 1;
#endif

        if(threadIdx.y == first_remaining_peer) // Nominate lowest-indexed warp as the leader
        {
          matchmask ^= (1 << first_remaining_peer);
          while(matchmask)
          {
#if defined(USE_ROCM)
            first_remaining_peer = __ffsll(matchmask) - 1;
#else
            first_remaining_peer = __ffs(matchmask) - 1;
#endif
            my_s[threadIdx.x] += smem[threadIdx.x + C10_WARP_SIZE*first_remaining_peer];
            matchmask ^= (1 << first_remaining_peer);
          }
          if(f < s)
            grad_weight[dst_row*stride + f] += static_cast<scalar_t>(my_s[threadIdx.x]);
        }
      }
    }
  }
}


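// Accumulates rows of grad_output into grad_weight given indices that are already
// sorted: `input` holds the sorted indices, `indices` the corresponding original
// positions, and `count` (when non-null) the per-element occurrence count used to
// scale each contribution by 1/frequency.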
template <typename scalar_t, typename index_t>
__global__ void embedding_backward_kernel(
  index_t* input, index_t* indices, scalar_t* grad_output, scalar_t* grad_weight,
  index_t* count, int64_t numel, int64_t stride, int padding_idx) {

  using accscalar_t = acc_type<scalar_t, true>;
  int idx = blockIdx.x * 4 + threadIdx.y;

  // Each warp is responsible for an input into the LookupTable.
  // If the preceding input has the same value as this input, then the warp
  // exits immediately. The warp also processes subsequent inputs with the
  // same value.
  //
  // Input  Warp
  //   1      <warp 1>
  //   1      <warp 1> (<warp 2> exits without doing any work)
  //   5      <warp 3>
  //   8      <warp 4>

  // Number of values processed by each thread (grain size)
  const int SZ = 4;

  if (idx < numel
      && (idx == 0 || input[idx] != input[idx - 1])
      && input[idx] != padding_idx) {
    do {
      const int start_feature = threadIdx.x + blockIdx.y * blockDim.x * SZ;
      const int weight_row = ((int) input[idx]) * stride;
      const int grad_row = ((int) indices[idx]) * stride;
      const accscalar_t scale = count ? (accscalar_t)1.0 / count[idx] : 1.0;

      accscalar_t gradient[SZ];
      accscalar_t weight[SZ];

      #pragma unroll
      for (int ii = 0; ii < SZ; ii++) {
        int feature_dim = start_feature + ii * C10_WARP_SIZE;
        if (feature_dim < stride) {
          gradient[ii] = static_cast<accscalar_t>(grad_output[grad_row + feature_dim]);
          weight[ii] = static_cast<accscalar_t>(grad_weight[weight_row + feature_dim]);
        }
      }

      #pragma unroll
      for (int ii = 0; ii < SZ; ii++) {
        weight[ii] += gradient[ii] * scale;
      }

      #pragma unroll
      for (int ii = 0; ii < SZ; ii++) {
        int feature_dim = start_feature + ii * C10_WARP_SIZE;
        if (feature_dim < stride) {
          grad_weight[weight_row + feature_dim] = static_cast<scalar_t>(weight[ii]);
        }
      }

      idx++;
    } while (idx < numel && input[idx] == input[idx - 1]);
  }
}

/* Renormalize in place the rows of `weights` selected by `indices`: compute each row's
   norm_type-norm and rescale the row whenever that norm exceeds max_norm. */
template <typename scalar_t, typename accscalar_t, typename index_t>
__global__ void renorm_kernel(
    scalar_t* weights, index_t* indices, accscalar_t max_norm,
    accscalar_t norm_type, int64_t dim,
    int64_t weights_stride0, int64_t weights_stride1,
    const int64_t *num_unique_indices) {
  if (blockIdx.x >= *num_unique_indices) {
    return;
  }

  // Some casting hacks since dynamic shared memory and templates don't work together:
  extern __shared__ unsigned char smem[];
  auto sdata = reinterpret_cast<accscalar_t*>(smem);

  int tid = threadIdx.x;
  int base_index = indices[blockIdx.x] * weights_stride0;

  accscalar_t v = 0;
  for (int i = tid; i < dim; i += blockDim.x) {
    auto x = static_cast<accscalar_t>(weights[base_index + i * weights_stride1]);
    if (norm_type == 1) {
      v += std::abs(x);
    } else if (norm_type == 2) {
      v += x * x;
    } else {
      v += std::pow(x, norm_type);
    }
  }

  v = cuda_utils::BlockReduceSum(v, sdata);

  if (tid == 0) {
    sdata[0] = std::pow(v, static_cast<accscalar_t>(1.0 / norm_type));
  }
  __syncthreads();

  // now we renormalize the blocks that need it
  if (sdata[0] > max_norm) {
    auto factor = static_cast<scalar_t>(max_norm / (sdata[0] + 1e-7));
    for (int i = tid; i < dim; i += blockDim.x) {
      weights[base_index + i * weights_stride1] *= factor;
    }
  }
}

} // anonymous namespace

#if !CUB_SUPPORTS_SCAN_BY_KEY()
template<typename index_t>
void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);
#endif

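// Dense (non-sparse) backward of the CUDA embedding lookup. Two paths:
//   * few indices and no frequency scaling: accumulate directly with the
//     deterministic embedding_backward_feature_kernel above;
//   * otherwise: radix-sort the indices (keeping their original positions), optionally
//     compute per-index occurrence counts, and hand off to embedding_backward_cuda_kernel.
//
// Rough Python-level sketch of how this is reached (illustrative only, variable names
// are made up):
//   weight = torch.randn(10, 4, device="cuda", requires_grad=True)
//   out = torch.nn.functional.embedding(idx, weight)  // idx is a CUDA index tensor
//   out.sum().backward()                              // dense embedding backward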
Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indices_,
                                     int64_t num_weights, int64_t padding_idx,
                                     bool scale_grad_by_freq) {
  auto grad_arg = TensorArg(grad_, "grad", 1);
  auto indices_arg = TensorArg(indices_, "indices", 1);
  checkScalarTypes("embedding_backward", indices_arg, {kLong, kInt});
  checkSameGPU("embedding_backward", grad_arg, indices_arg);

  auto indices = indices_.contiguous();

  auto num_indices = indices.numel();
  auto grad = grad_.contiguous().view({num_indices, grad_.size(-1)});
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

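  // Fast path: with few indices and no frequency scaling, accumulate straight into a
  // zero-initialized grad_weight using the shared-memory kernel above.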
  if (num_indices <= 3072 && !scale_grad_by_freq) {
    auto indices_contig = indices.contiguous();
    auto grad_weight = at::zeros({num_weights, grad_.size(-1)}, grad_.options());
    int64_t stride = grad_weight.stride(0);
    int warp_size = at::cuda::warp_size();
    dim3 grid(ceil_div(stride, (int64_t)warp_size));
    dim3 block(warp_size, BLOCKDIMY);

    AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      grad.scalar_type(),
      "embedding_backward",
      [&]
      {
        using accscalar_t = acc_type<scalar_t, true>;
        AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
          embedding_backward_feature_kernel<scalar_t, accscalar_t, index_t>
            <<<grid,
               block,
               sizeof(accscalar_t)*warp_size*BLOCKDIMY + sizeof(int)*warp_size*BLOCKDIMY,
               stream>>>
            (indices_contig.const_data_ptr<index_t>(),
             grad.const_data_ptr<scalar_t>(),
             grad_weight.mutable_data_ptr<scalar_t>(),
             static_cast<int>(num_indices),
             static_cast<int64_t>(stride),
             static_cast<int>(padding_idx));
          C10_CUDA_KERNEL_LAUNCH_CHECK();
        });
      });
    return grad_weight;
  }

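  // General path: sort the indices, carrying their original positions along as values,
  // so that duplicate indices form contiguous runs that can be accumulated together.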
  auto sorted_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto orig_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  Tensor count;
  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
    auto range = at::arange(num_indices, indices.options());
    int64_t nbits = cuda::cub::get_num_bits(num_weights);
    cuda::cub::radix_sort_pairs(
      indices.const_data_ptr<index_t>(), sorted_indices.mutable_data_ptr<index_t>(),
      range.const_data_ptr<index_t>(), orig_indices.mutable_data_ptr<index_t>(),
      num_indices, false/*, 0, nbits*/);
  });

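  // When scaling by frequency, compute for every element the total number of
  // occurrences of its index; the two scans below build this (see the worked
  // examples in the comments).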
  if (scale_grad_by_freq) {
    count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
#if CUB_SUPPORTS_SCAN_BY_KEY()
    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
      cudaStream_t stream = at::cuda::getCurrentCUDAStream();

      // Compute an increasing sequence per unique item in sortedIndices:
      // sorted: 2 5 5 5 7 7 8 9 9
      //  count: 1 1 2 3 1 2 1 1 2
      auto sorted_data = sorted_indices.const_data_ptr<index_t>();
      auto count_data = count.mutable_data_ptr<index_t>();
      cuda::cub::inclusive_sum_by_key(
        sorted_data,
        at_cuda_detail::cub::ConstantInputIterator<index_t>(1),
        count_data,
        num_indices
      );

      // Take the maximum of each count per unique key in reverse:
      // sorted: 2 5 5 5 7 7 8 9 9
      //  count: 1 3 3 3 2 2 1 2 2
      cuda::cub::inclusive_scan_by_key(
        thrust::make_reverse_iterator(sorted_data + num_indices),
        thrust::make_reverse_iterator(static_cast<const index_t*>(count_data) + num_indices),
        thrust::make_reverse_iterator(count_data + num_indices),
        at_cuda_detail::cub::Max(),
        num_indices
      );
    });
#else
    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
      embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count);
    });
#endif
  }

  return embedding_backward_cuda_kernel(grad, orig_indices,
      sorted_indices, count, num_weights, padding_idx);
}

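// In-place renorm of the rows of `self` referenced by `indices`: every referenced row
// whose norm_type-norm exceeds max_norm is scaled back down to (approximately)
// max_norm. This is the CUDA path behind the max_norm option of nn.Embedding.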
Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices,
                                double max_norm, double norm_type) {
  auto self_arg = TensorArg(self, "self", 1);
  auto indices_arg = TensorArg(indices, "indices", 1);
  checkDim("embedding_renorm_", self_arg, 2);
  checkSameGPU("embedding_renorm", self_arg, indices_arg);

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_renorm_cuda_", [&] () {

    auto num_indices = indices.numel();
    auto indices_contig = std::get<0>(indices.sort()).contiguous();
    auto unique_indices = at::empty(indices.numel(), indices.options());
    auto num_unique_indices = at::empty({}, indices.options().dtype(kLong));

    cuda::cub::unique(
      indices_contig.const_data_ptr<index_t>(),
      unique_indices.mutable_data_ptr<index_t>(),
      num_unique_indices.mutable_data_ptr<int64_t>(),
      num_indices
    );

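    // unique_indices is over-allocated to num_indices entries; the true number of
    // unique values is written to num_unique_indices on the device, and renorm_kernel
    // exits early for blocks beyond that count.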
    int warp_size = at::cuda::warp_size();
    TORCH_INTERNAL_ASSERT(num_threads() % warp_size == 0 &&
                          num_threads() <= cuda_utils::kCUDABlockReduceMaxThreads,
                          "BlockReduceSum requires all warps be active");
    const int64_t *num_unique_indices_ptr = num_unique_indices.const_data_ptr<int64_t>();
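    // One block per slot of unique_indices (an upper bound on the number of unique
    // indices); surplus blocks return immediately inside renorm_kernel.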
    dim3 grid = unique_indices.numel();
    dim3 block = num_threads();
    int dim = self.stride(0);

    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "embedding_renorm_cuda_", [&] {
      using accscalar_t = acc_type<scalar_t, true>;
      renorm_kernel<<<grid, block, (block.x / warp_size) * sizeof(accscalar_t), stream>>>(
        self.mutable_data_ptr<scalar_t>(),
        unique_indices.const_data_ptr<index_t>(),
        static_cast<accscalar_t>(max_norm),
        static_cast<accscalar_t>(norm_type),
        dim, self.stride(0), self.stride(1),
        num_unique_indices_ptr);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
  });
  return self;
}


} // namespace at::native