xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/Embedding.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/AccumulateType.h>
#include <ATen/Dispatch.h>
#include <ATen/TensorUtils.h>
#include <ATen/ceil_div.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Exception.h>
#include <c10/macros/Macros.h>

#include <ATen/cuda/cub.cuh>

#include <ATen/native/cuda/EmbeddingBackwardKernel.cuh>
#include <ATen/native/cuda/SortingCommon.cuh>
#include <ATen/native/cuda/block_reduce.cuh>
#include <ATen/native/cuda/thread_constants.h>

#if CUB_SUPPORTS_SCAN_BY_KEY()
#include <thrust/iterator/reverse_iterator.h>
#endif

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/arange.h>
#include <ATen/ops/embedding_dense_backward_native.h>
#include <ATen/ops/embedding_renorm_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/zeros.h>
#endif

namespace at::native {

namespace {

#if defined(USE_ROCM)
static const int BLOCKDIMY = 16;
#else
static const int BLOCKDIMY = 32;
#endif
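// Number of warps (threadIdx.y values) per block used by
// embedding_backward_feature_kernel below. Presumably halved on ROCm because
// the wavefront (warp) size there is typically 64, keeping the block at
// 1024 threads on either platform.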

template
  <typename scalar_t,
   typename accscalar_t,
   typename index_t>
__global__ void embedding_backward_feature_kernel
  (const index_t* indices,
   const scalar_t* __restrict__ grad,
   scalar_t* __restrict__ grad_weight,
   int n, // OK to pass as int, we don't expect 2 billion+ samples in one shot
   int64_t stride,
   int padding_idx)
{
  extern __shared__ char buf[];
  accscalar_t* smem = (accscalar_t*)buf;
  accscalar_t* my_s = smem + C10_WARP_SIZE*threadIdx.y;
  int* indices_batch = (int*)(buf + sizeof(accscalar_t)*C10_WARP_SIZE*blockDim.y);
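  // Dynamic shared memory layout (sized at the launch site in
  // embedding_dense_backward_cuda):
  //   [accscalar_t x C10_WARP_SIZE*blockDim.y]  one warp-wide accumulation segment per warp (smem / my_s)
  //   [int         x C10_WARP_SIZE*blockDim.y]  cached batch of indices (indices_batch)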

  const int s = (int)stride; // OK to make int, we don't expect 2 billion+ embedding row size

  const int f = threadIdx.x + blockIdx.x*blockDim.x; // feature_dim

  for(int batch_start = 0; batch_start < n; batch_start += blockDim.x*blockDim.y)
  {
    // Entire block cooperates to load a batch of 1024 indices to process
    int tid = threadIdx.x + threadIdx.y*blockDim.x;
    if(batch_start + tid < n)
      indices_batch[tid] = (int)indices[batch_start + tid];

    int batch_end = batch_start + blockDim.x*blockDim.y < n ?
                    batch_start + blockDim.x*blockDim.y : n;

    // Loop over the batch of <= 1024 loaded indices in chunks of blockDim.y = 32
    for(int chunk_start = batch_start; chunk_start < batch_end; chunk_start += blockDim.y)
    {
      // This does double duty:  it makes sure indices_batch is ready, and it makes sure match-group
      // leaders are done with their accumulates before other warps start loading again.
      __syncthreads();

      int n_this_chunk = (batch_end - chunk_start) < blockDim.y ?
                         (batch_end - chunk_start) : blockDim.y;

      int src_row = chunk_start + threadIdx.y;
      int dst_row = indices_batch[src_row - batch_start]; // This warp's target row in grad_weight

      // All warps load their smem segments with incoming grad data
      if(src_row < n && f < s && dst_row != padding_idx)
        my_s[threadIdx.x] = static_cast<accscalar_t>(grad[src_row*stride + f]);

      __syncthreads();

      // To ensure determinism, we can't just have each warp add its grad data to its dst_row.
      // We need to check if any other warps pulled grad data targeting dst_row.
      // If so, we elect the first warp in each matching group as the leader.
      // Each leader warp serializes the accumulates targeting dst_row in shared memory,
      // then finishes by adding the accumulated buffer to dst_row in grad_weight.
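      //
      // Illustrative sketch (not executed): suppose the warps of a chunk loaded
      // dst_row = {5, 7, 5, 9} for threadIdx.y = 0..3. Warps 0 and 2 form a match
      // group for row 5; warp 0, the lowest-indexed member, is elected leader,
      // adds warp 2's shared-memory segment into its own, and alone writes the
      // sum into row 5 of grad_weight. Warp 2 contributes nothing further for
      // this chunk, while warps 1 and 3 are trivially the leaders of their own rows.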
      if(dst_row != padding_idx && src_row < n) // Per-warp exit condition, safe with ballot_sync
      {
        int match_found_this_thread = 0;
        if(threadIdx.x < n_this_chunk)
          match_found_this_thread = (dst_row == indices_batch[chunk_start - batch_start + threadIdx.x]);
#if defined(USE_ROCM)
        unsigned long long int matchmask = WARP_BALLOT(match_found_this_thread);
        int first_remaining_peer = __ffsll(matchmask) - 1;
#else
        unsigned int matchmask = WARP_BALLOT(match_found_this_thread);
        int first_remaining_peer = __ffs(matchmask) - 1;
#endif

        if(threadIdx.y == first_remaining_peer) // Nominate lowest-indexed warp as the leader
        {
          matchmask ^= (1 << first_remaining_peer);
          while(matchmask)
          {
#if defined(USE_ROCM)
            first_remaining_peer = __ffsll(matchmask) - 1;
#else
            first_remaining_peer = __ffs(matchmask) - 1;
#endif
            my_s[threadIdx.x] += smem[threadIdx.x + C10_WARP_SIZE*first_remaining_peer];
            matchmask ^= (1 << first_remaining_peer);
          }
          if(f < s)
            grad_weight[dst_row*stride + f] += static_cast<scalar_t>(my_s[threadIdx.x]);
        }
      }
    }
  }
}


template <typename scalar_t, typename index_t>
__global__ void embedding_backward_kernel(
  index_t* input, index_t* indices, scalar_t* grad_output, scalar_t* grad_weight,
  index_t* count, int64_t numel, int64_t stride, int padding_idx) {

  using accscalar_t = acc_type<scalar_t, true>;
  int idx = blockIdx.x * 4 + threadIdx.y;

  // Each warp is responsible for an input into the LookupTable.
  // If the preceding input has the same value as this input, then the warp
  // exits immediately. The warp also processes subsequent inputs with the
  // same value.
  //
  // Input Warp
  // 1     <warp 1>
  // 1     <warp 1> (<warp 2> exits without doing any work)
  // 5     <warp 3>
  // 8     <warp 4>

  // Number of values processed by each thread (grain size)
  const int SZ = 4;
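  // Sketch of the assumed launch geometry: the indexing above
  // (idx = blockIdx.x * 4 + threadIdx.y, feature chunks strided by
  // C10_WARP_SIZE) implies blockDim = (C10_WARP_SIZE, 4); each warp then owns
  // one position of the sorted input and covers up to SZ * C10_WARP_SIZE
  // features, with blockIdx.y selecting which slice of the embedding row it
  // touches.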

  if (idx < numel
      && (idx == 0 || input[idx] != input[idx - 1])
      && input[idx] != padding_idx) {
    do {
      const int start_feature = threadIdx.x + blockIdx.y * blockDim.x * SZ;
      const int weight_row = ((int) input[idx]) * stride;
      const int grad_row = ((int) indices[idx]) * stride;
      const accscalar_t scale = count ? (accscalar_t)1.0 / count[idx] : 1.0;
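      // count, when provided, is expected to hold the number of occurrences of
      // input[idx] in the batch (cf. the scan-by-key passes in
      // embedding_dense_backward_cuda), so the gradient of a repeated index is
      // scaled by 1 / frequency, i.e. scale_grad_by_freq.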

      accscalar_t gradient[SZ];
      accscalar_t weight[SZ];

      #pragma unroll
      for (int ii = 0; ii < SZ; ii++) {
        int feature_dim = start_feature + ii * C10_WARP_SIZE;
        if (feature_dim < stride) {
          gradient[ii] = static_cast<accscalar_t>(grad_output[grad_row + feature_dim]);
          weight[ii] = static_cast<accscalar_t>(grad_weight[weight_row + feature_dim]);
        }
      }

      #pragma unroll
      for (int ii = 0; ii < SZ; ii++) {
        weight[ii] += gradient[ii] * scale;
      }

      #pragma unroll
      for (int ii = 0; ii < SZ; ii++) {
        int feature_dim = start_feature + ii * C10_WARP_SIZE;
        if (feature_dim < stride) {
            grad_weight[weight_row + feature_dim] = static_cast<scalar_t>(weight[ii]);
        }
      }

      idx++;
    } while (idx < numel && input[idx] == input[idx - 1]);
  }
}

/* Renormalize the rows of weights selected by indices so that each selected row's norm_type-norm does not exceed max_norm */
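/* Concretely, each block reduces one selected row w to
 *   norm = (sum_i |w[i]|^norm_type)^(1/norm_type)
 * and, if norm > max_norm, scales the row in place by max_norm / (norm + 1e-7). */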
template <typename scalar_t, typename accscalar_t, typename index_t>
__global__ void renorm_kernel(
    scalar_t* weights, index_t* indices, accscalar_t max_norm,
    accscalar_t norm_type, int64_t dim,
    int64_t weights_stride0, int64_t weights_stride1,
    const int64_t *num_unique_indices) {
  if (blockIdx.x >= *num_unique_indices) {
    return;
  }

  // Some casting hacks since dynamic shared memory and templates don't work together:
  extern __shared__ unsigned char smem[];
  auto sdata = reinterpret_cast<accscalar_t*>(smem);

  int tid = threadIdx.x;
  int base_index = indices[blockIdx.x] * weights_stride0;

  accscalar_t v = 0;
  for (int i = tid; i < dim; i += blockDim.x) {
    auto x = static_cast<accscalar_t>(weights[base_index + i * weights_stride1]);
    if (norm_type == 1) {
      v += std::abs(x);
    } else if (norm_type == 2) {
      v += x * x;
    } else {
      v += std::pow(x, norm_type);
    }
  }

  v = cuda_utils::BlockReduceSum(v, sdata);

  if (tid == 0) {
    sdata[0] = std::pow(v, static_cast<accscalar_t>(1.0 / norm_type));
  }
  __syncthreads();

  // now we renormalize the rows that need it
  if (sdata[0] > max_norm) {
    auto factor = static_cast<scalar_t>(max_norm / (sdata[0] + 1e-7));
    for (int i = tid; i < dim; i += blockDim.x) {
      weights[base_index + i * weights_stride1] *= factor;
    }
  }
}

} // anonymous namespace

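// Thrust-based fallback for computing per-index counts; it is defined in a
// separate translation unit and only used when CUB's scan-by-key primitives
// are unavailable (see the #else branch in embedding_dense_backward_cuda).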
#if !CUB_SUPPORTS_SCAN_BY_KEY()
template<typename index_t>
void embedding_dense_backward_cuda_scan(Tensor &sorted_indices, Tensor &count);
#endif

Tensor embedding_dense_backward_cuda(const Tensor & grad_, const Tensor & indices_,
                               int64_t num_weights, int64_t padding_idx,
                               bool scale_grad_by_freq) {
  auto grad_arg = TensorArg(grad_, "grad", 1);
  auto indices_arg = TensorArg(indices_, "indices", 1);
  checkScalarTypes("embedding_backward", indices_arg, {kLong, kInt});
  checkSameGPU("embedding_backward", grad_arg, indices_arg);

  auto indices = indices_.contiguous();

  auto num_indices = indices.numel();
  auto grad = grad_.contiguous().view({num_indices, grad_.size(-1)});
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

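  // Fast path for small batches: a single kernel accumulates directly into
  // grad_weight, resolving duplicate indices deterministically by serializing
  // them within each block. It does not compute per-index counts, so it cannot
  // honor scale_grad_by_freq; larger batches (the 3072 cutoff is a heuristic)
  // fall through to the sort-based path below.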
  if (num_indices <= 3072 && !scale_grad_by_freq) {
    auto indices_contig = indices.contiguous();
    auto grad_weight = at::zeros({num_weights, grad_.size(-1)}, grad_.options());
    int64_t stride = grad_weight.stride(0);
    int warp_size = at::cuda::warp_size();
    dim3 grid(ceil_div(stride, (int64_t)warp_size));
    dim3 block(warp_size, BLOCKDIMY);

    AT_DISPATCH_FLOATING_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      grad.scalar_type(),
       "embedding_backward",
       [&]
       {
          using accscalar_t = acc_type<scalar_t, true>;
          AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
          embedding_backward_feature_kernel<scalar_t, accscalar_t, index_t>
            <<<grid,
                block,
                sizeof(accscalar_t)*warp_size*BLOCKDIMY + sizeof(int)*warp_size*BLOCKDIMY,
                stream>>>
            (indices_contig.const_data_ptr<index_t>(),
              grad.const_data_ptr<scalar_t>(),
              grad_weight.mutable_data_ptr<scalar_t>(),
              static_cast<int>(num_indices),
              static_cast<int64_t>(stride),
              static_cast<int>(padding_idx));
          C10_CUDA_KERNEL_LAUNCH_CHECK();
          });
       });
    return grad_weight;
  }

  auto sorted_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  auto orig_indices = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
  Tensor count;
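  // Sort the indices, carrying their original positions along as values, so
  // that equal indices form contiguous runs; orig_indices then maps every
  // sorted element back to its row in grad.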
  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
    auto range = at::arange(num_indices, indices.options());
    int64_t nbits = cuda::cub::get_num_bits(num_weights);
    cuda::cub::radix_sort_pairs(
      indices.const_data_ptr<index_t>(), sorted_indices.mutable_data_ptr<index_t>(),
      range.const_data_ptr<index_t>(), orig_indices.mutable_data_ptr<index_t>(),
      num_indices, false/*, 0, nbits*/);
  });

  if (scale_grad_by_freq) {
    count = at::empty_like(indices, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
#if CUB_SUPPORTS_SCAN_BY_KEY()
    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
      cudaStream_t stream = at::cuda::getCurrentCUDAStream();

      // Compute an increasing sequence per unique item in sortedIndices:
      // sorted: 2 5 5 5 7 7 8 9 9
      //  count: 1 1 2 3 1 2 1 1 2
      auto sorted_data = sorted_indices.const_data_ptr<index_t>();
      auto count_data = count.mutable_data_ptr<index_t>();
      cuda::cub::inclusive_sum_by_key(
        sorted_data,
        at_cuda_detail::cub::ConstantInputIterator<index_t>(1),
        count_data,
        num_indices
      );

      // Take the maximum of each count per unique key in reverse:
      // sorted: 2 5 5 5 7 7 8 9 9
      //  count: 1 3 3 3 2 2 1 2 2
      cuda::cub::inclusive_scan_by_key(
        thrust::make_reverse_iterator(sorted_data + num_indices),
        thrust::make_reverse_iterator(static_cast<const index_t*>(count_data) + num_indices),
        thrust::make_reverse_iterator(count_data + num_indices),
        at_cuda_detail::cub::Max(),
        num_indices
      );
    });
#else
    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_dense_backward_cuda", [&] () {
      embedding_dense_backward_cuda_scan<index_t>(sorted_indices, count);
    });
#endif
  }

  return embedding_backward_cuda_kernel(grad, orig_indices,
      sorted_indices, count, num_weights, padding_idx);
}

Tensor & embedding_renorm_cuda_(Tensor & self, const Tensor & indices,
                                double max_norm, double norm_type) {
  auto self_arg = TensorArg(self, "self", 1);
  auto indices_arg = TensorArg(indices, "indices", 1);
  checkDim("embedding_renorm_", self_arg, 2);
  checkSameGPU("embedding_renorm", self_arg, indices_arg);

  cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "embedding_renorm_cuda_", [&] () {

    auto num_indices = indices.numel();
    auto indices_contig = std::get<0>(indices.sort()).contiguous();
    auto unique_indices = at::empty(indices.numel(), indices.options());
    auto num_unique_indices = at::empty({}, indices.options().dtype(kLong));

    cuda::cub::unique(
      indices_contig.const_data_ptr<index_t>(),
      unique_indices.mutable_data_ptr<index_t>(),
      num_unique_indices.mutable_data_ptr<int64_t>(),
      num_indices
    );
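    // unique_indices[0 .. *num_unique_indices) now holds the distinct indices.
    // The kernel below is launched with one block per (possibly duplicated)
    // element, and the surplus blocks exit early via the *num_unique_indices
    // check at the top of renorm_kernel.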

    int warp_size = at::cuda::warp_size();
    TORCH_INTERNAL_ASSERT(num_threads() % warp_size == 0 &&
                  num_threads() <= cuda_utils::kCUDABlockReduceMaxThreads,
                  "BlockReduceSum requires all warps be active");
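    // BlockReduceSum reduces within each warp via shuffles and then combines
    // the per-warp partial sums through shared memory, so the block must
    // consist of whole warps and fit within kCUDABlockReduceMaxThreads.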
    const int64_t *num_unique_indices_ptr = num_unique_indices.const_data_ptr<int64_t>();
    dim3 grid = unique_indices.numel();
    dim3 block = num_threads();
    int dim = self.stride(0);

    AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "embedding_renorm_cuda_", [&] {
      using accscalar_t = acc_type<scalar_t, true>;
      renorm_kernel<<<grid, block, (block.x / warp_size) * sizeof(accscalar_t), stream>>>(
        self.mutable_data_ptr<scalar_t>(),
        unique_indices.const_data_ptr<index_t>(),
        static_cast<accscalar_t>(max_norm),
        static_cast<accscalar_t>(norm_type),
        dim, self.stride(0), self.stride(1),
        num_unique_indices_ptr);
      C10_CUDA_KERNEL_LAUNCH_CHECK();
    });
  });
  return self;
}


}  // namespace at::native