1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/TensorAdvancedIndexing.h>
3 #include <ATen/native/IndexingUtils.h>
4 #include <ATen/native/quantized/IndexKernel.h>
5 #include <ATen/native/cuda/KernelUtils.cuh>
6 
7 #include <ATen/core/Tensor.h>
8 #include <ATen/ceil_div.h>
9 #include <ATen/Dispatch.h>
10 #include <ATen/Dispatch_v2.h>
11 #include <ATen/ExpandUtils.h>
12 #include <ATen/MemoryOverlap.h>
13 #include <ATen/TensorOperators.h>
14 #include <ATen/native/TensorIterator.h>
15 #include <ATen/native/cuda/Loops.cuh>
16 #include <ATen/native/Resize.h>
17 #include <ATen/cuda/detail/IndexUtils.cuh>
18 #include <ATen/cuda/CUDAUtils.h>
19 #include <ATen/cuda/DeviceUtils.cuh>
20 
21 #ifndef AT_PER_OPERATOR_HEADERS
22 #include <ATen/Functions.h>
23 #include <ATen/NativeFunctions.h>
24 #else
25 #include <ATen/ops/_assert_async.h>
26 #include <ATen/ops/arange.h>
27 #include <ATen/ops/empty.h>
28 #include <ATen/ops/zeros_like.h>
29 #include <ATen/ops/ones_like.h>
30 #include <ATen/ops/empty_quantized.h>
31 #include <ATen/ops/index_add_native.h>
32 #include <ATen/ops/index_reduce_native.h>
33 #include <ATen/ops/index_select_native.h>
34 #include <ATen/ops/masked_fill_native.h>
35 #include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors.h>
36 #endif
37 
38 #include <ATen/cuda/CUDAContext.h>
39 #include <ATen/cuda/cub.h>
40 #include <c10/util/irange.h>
41 #include <c10/core/QScheme.h>
42 #include <ATen/native/quantized/AffineQuantizerBase.h>
43 
44 #include <limits>
45 
46 #include <c10/macros/Macros.h>
47 
48 namespace {
49 template <typename scalar_t, int SZ>
50 __global__ void indexing_backward_kernel(
51   const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
52   int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
53 // numel is the total number of flattened indices, not expanded to dimensions that are not indexed.
54 // stride is the cumulative size of the un-indexed trailing dimensions.
55 // stride_before is the stride of the dimension immediately preceding the first indexed dimension;
56 // if indexing starts from the 0th dimension, stride_before does not matter because blockIdx.z is 0 in that case.
57 // outer_dim is the number of elements in the leading un-indexed dimensions.
58   using opmath_t = at::opmath_type<scalar_t>;
59 
60   // Each warp is responsible for a single input index. If the preceding input
61   // has the same destination index as this input, then the warp
62   // exits immediately. The warp also processes subsequent inputs with the
63   // same destination index.
64   //
65   // Input Warp
66   // 1     <warp 1>
67   // 1     <warp 1> (<warp 2> exits without doing any work)
68   // 5     <warp 3>
69   // 8     <warp 4>
70 
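  // Worked example (hypothetical shapes): suppose grad_weight has shape [V, D]
  // and the forward op indexed x[idx] with idx = [1, 1, 3]. Then numel == 3,
  // stride == D, stride_before == 0 and outer_dim == 1. After sorting, both
  // occurrences of index 1 are adjacent, so a single warp either accumulates
  // both corresponding rows of grad_output (accumulate == true) or keeps only
  // the last one (accumulate == false).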
71   // Number of values processed by each thread (grain size)
72   for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z){
73     int64_t idx = blockIdx.x * blockDim.y + threadIdx.y;
74     if (idx < numel
75         && (idx == 0 || sorted_indices[idx] != sorted_indices[idx - 1])){
76       do {
77         int64_t start_feature = threadIdx.x + blockIdx.y * blockDim.x * SZ;
78         // if not accumulate, we only keep the last duplicate index so skip those before it
79         if (!accumulate && (idx < numel - 1) && sorted_indices[idx] == sorted_indices[idx + 1]) {
80           idx++;
81           continue;
82         }
83         const int64_t weight_row = ((int64_t) sorted_indices[idx]) * stride + z * stride_before;
84         const int64_t grad_row = ((int64_t) indices[idx]) * stride + z * numel * stride;
85         const opmath_t scale = (opmath_t)1.0;
86 
87         opmath_t gradient[SZ];
88         opmath_t weight[SZ];
89 
90         while (start_feature < stride) {
91           #pragma unroll
92           for (int ii = 0; ii < SZ; ii++) {
93             int64_t feature_dim = start_feature + ii * C10_WARP_SIZE;
94             if (feature_dim < stride) {
95               gradient[ii] = static_cast<opmath_t>(grad_output[grad_row + feature_dim]);
96               if (accumulate) {
97                 weight[ii] = static_cast<opmath_t>(grad_weight[weight_row + feature_dim]);
98               }
99             }
100           }
101 
102           #pragma unroll
103           for (int ii = 0; ii < SZ; ii++) {
104             if (accumulate) {
105               weight[ii] += gradient[ii] * scale;
106             } else {
107               weight[ii] = gradient[ii] * scale;
108             }
109           }
110 
111           #pragma unroll
112           for (int ii = 0; ii < SZ; ii++) {
113             int64_t feature_dim = start_feature + ii * C10_WARP_SIZE;
114             if (feature_dim < stride) {
115                 grad_weight[weight_row + feature_dim] = static_cast<scalar_t>(weight[ii]);
116             }
117           }
118           start_feature += gridDim.y * blockDim.x * SZ;
119         }
120 
121         idx++;
122       } while (idx < numel && sorted_indices[idx] == sorted_indices[idx - 1]);
123     }
124   }
125 }
126 
127 template <typename scalar_t>
128 __global__ void indexing_backward_kernel_stride_1(
129   const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
130   int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
131   using opmath_t = at::opmath_type<scalar_t>;
132 
133   // Number of values processed by each thread (grain size)
134   for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z){
135     int64_t idx = blockIdx.x * blockDim.y + threadIdx.y;
136     int64_t crnt_sorted_idx = sorted_indices[idx];
137 
138     if ((idx < numel) &&
139         (idx == 0 || crnt_sorted_idx != sorted_indices[idx - 1]))
140     {
141       // Determine the number of duplicates in advance
142       int64_t num_duplicates = 1;
143       while (((idx + num_duplicates) < numel) && (sorted_indices[idx + num_duplicates] == crnt_sorted_idx)) {
144         num_duplicates++;
145       }
146 
147       // Continue computing weights
148       const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before;
149       int64_t grad_row = 0;
150       const opmath_t scale = (opmath_t)1.0;
151 
152       if (!accumulate) {
153         grad_row = ((int64_t)indices[idx + num_duplicates - 1]) * stride + z * numel * stride;
154         grad_weight[weight_row] =
155           static_cast<scalar_t>(static_cast<opmath_t>(grad_output[grad_row]) * scale);
156       } else {
157         opmath_t gradient = (opmath_t)0.0;
158 
159         int laneIdx = threadIdx.x % C10_WARP_SIZE;
160         int64_t num_warp_passes = num_duplicates / C10_WARP_SIZE;
161         for (int64_t i = 0; i < num_warp_passes; ++i) {
162             grad_row = ((int64_t) indices[idx + i * C10_WARP_SIZE + laneIdx]) * stride + z * numel * stride;
163             gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
164         }
165         WARP_SYNC();
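        // Tree-reduce the per-lane partial sums: after log2(warpSize)
        // shuffle-down steps, lane 0 holds the sum of all lanes' partial gradients.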
166         for (int offset = C10_WARP_SIZE / 2; offset > 0; offset /= 2) {
167           gradient += WARP_SHFL_DOWN(gradient, offset);
168         }
169 
170         if (laneIdx == 0) {
171           for (int64_t i = num_warp_passes * C10_WARP_SIZE; i < num_duplicates; ++i) {
172             grad_row = ((int64_t) indices[idx + i]) * stride + z * numel * stride;
173             gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
174           }
175 
176           grad_weight[weight_row] = static_cast<scalar_t>(static_cast<opmath_t>(grad_weight[weight_row]) + gradient);
177         }
178       }
179     }
180   }
181 }
182 
183 template <typename scalar_t>
184 __global__ void indexing_backward_kernel_small_stride(
185   const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
186   int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
187   using opmath_t = at::opmath_type<scalar_t>;
188 
189   // Number of values processed by each thread (grain size)
190   for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z){
191     int64_t idx = blockIdx.x * blockDim.y + threadIdx.y;
192     int64_t tidx = threadIdx.x;
193     int64_t crnt_sorted_idx = sorted_indices[idx];
194 
195     if ((idx < numel) &&
196         (tidx < stride) &&
197         (idx == 0 || crnt_sorted_idx != sorted_indices[idx - 1]))
198     {
199       // Determine the number of duplicates in advance
200       int64_t num_duplicates = 1;
201       while (((idx + num_duplicates) < numel) && (sorted_indices[idx + num_duplicates] == crnt_sorted_idx)) {
202         num_duplicates++;
203       }
204 
205       // Continue computing weights
206       const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before;
207       int64_t grad_row = 0;
208       const opmath_t scale = (opmath_t)1.0;
209 
210       if (!accumulate) {
211         grad_row = ((int64_t)indices[idx + num_duplicates - 1]) * stride + z * numel * stride;
212         grad_weight[weight_row + tidx] =
213           static_cast<scalar_t>(static_cast<opmath_t>(grad_output[grad_row + tidx]) * scale);
214       } else {
215         opmath_t gradient = (opmath_t)0.0;
216         for (int64_t i = 0; i < num_duplicates; ++i) {
217           grad_row = ((int64_t) indices[idx + i]) * stride + z * numel * stride;
218           gradient += static_cast<opmath_t>(grad_output[grad_row + tidx]) * scale;
219         }
220 
221         grad_weight[weight_row + tidx] = static_cast<scalar_t>(static_cast<opmath_t>(grad_weight[weight_row + tidx]) + gradient);
222       }
223     }
224   }
225 }
226 
227 template <typename scalar_t, int SZ>
228 __global__ void indexing_backward_kernel_quantized(
229   const int64_t* sorted_indices, const int64_t* indices, const float* grad_output, scalar_t* grad_weight,
230   int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim,
231   float inv_scale, int zero_point, int64_t qmin, int64_t qmax) {
232 
233   // This implementation is adapted from indexing_backward_kernel above.
234   using opmath_t = at::opmath_type<float>;
235   for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z){
236     int64_t idx = blockIdx.x * blockDim.y + threadIdx.y;
237     if (idx < numel
238         && (idx == 0 || sorted_indices[idx] != sorted_indices[idx - 1])){
239       do {
240         int64_t start_feature = threadIdx.x + blockIdx.y * blockDim.x * SZ;
241         // we only keep the last duplicate index so skip those before it
242         if ((idx < numel - 1) && sorted_indices[idx] == sorted_indices[idx + 1]) {
243           idx++;
244           continue;
245         }
246         const int64_t weight_row = ((int64_t) sorted_indices[idx]) * stride + z * stride_before;
247         const int64_t grad_row = ((int64_t) indices[idx]) * stride + z * numel * stride;
248         const opmath_t scale = (opmath_t)1.0;
249 
250         opmath_t gradient[SZ];
251         opmath_t weight[SZ];
252 
253         while (start_feature < stride) {
254           #pragma unroll
255           for (int ii = 0; ii < SZ; ii++) {
256             int64_t feature_dim = start_feature + ii * C10_WARP_SIZE;
257             if (feature_dim < stride) {
258               gradient[ii] = static_cast<opmath_t>(grad_output[grad_row + feature_dim]);
259             }
260           }
261 
262           #pragma unroll
263           for (int ii = 0; ii < SZ; ii++) {
264             weight[ii] = gradient[ii] * scale;
265           }
266 
267           #pragma unroll
268           for (int ii = 0; ii < SZ; ii++) {
269             int64_t feature_dim = start_feature + ii * C10_WARP_SIZE;
270             if (feature_dim < stride) {
271                 // Affine-quantize the value: q = clamp(round(v * inv_scale) + zero_point, qmin, qmax)
272                 int64_t qvalue = static_cast<int64_t>(zero_point + nearbyintf(weight[ii]* inv_scale));
273                 qvalue = min(max(qvalue, qmin), qmax);
274                 grad_weight[weight_row + feature_dim] = static_cast<scalar_t>(qvalue);
275             }
276           }
277           start_feature += gridDim.y * blockDim.x * SZ;
278         }
279 
280         idx++;
281       } while (idx < numel && sorted_indices[idx] == sorted_indices[idx - 1]);
282     }
283   }
284 }
285 
286 
287 }
288 
289 
290 namespace at::native {
291 
292 namespace {
293 
294 class ReduceMultiply {
295 public:
296   template <typename scalar_t>
297   constexpr C10_DEVICE void operator() (scalar_t* self_data_start, int64_t index, int64_t numel, const scalar_t * src_data) const {
298     (void)numel; // suppress unused warning
299     gpuAtomicMul(self_data_start + index, *src_data);
300   }
301 };
302 static ReduceMultiply reduce_multiply;
303 
304 class ReduceAdd {
305 public:
306   template <typename scalar_t>
307   constexpr C10_DEVICE void operator() (scalar_t* self_data_start, int64_t index, int64_t numel, const scalar_t * src_data) const {
308     fastAtomicAdd(self_data_start, index, numel, *src_data, true);
309   }
310 };
311 static ReduceAdd reduce_add;
312 
313 class ReduceMinimum {
314 public:
315   template <typename scalar_t>
316   constexpr C10_DEVICE void operator() (scalar_t* self_data_start, int64_t index, int64_t numel, const scalar_t * src_data) const {
317     (void)numel; // suppress unused warning
318     gpuAtomicMin(self_data_start + index, *src_data);
319   }
320 };
321 static ReduceMinimum reduce_minimum;
322 
323 class ReduceMaximum {
324 public:
325   template <typename scalar_t>
326   constexpr C10_DEVICE void operator() (scalar_t* self_data_start, int64_t index, int64_t numel, const scalar_t * src_data) const {
327     (void)numel; // suppress unused warning
328     gpuAtomicMax(self_data_start + index, *src_data);
329   }
330 };
331 static ReduceMaximum reduce_maximum;
332 
333 }
334 
335 static Tensor wrapIndexOnce(const Tensor & index, int64_t dim, int64_t dim_size, bool check_range=true) {
336 // We don't need to check the range in backward: if there were out-of-bounds indices, the forward pass should already have errored out.
337   if (index.numel() != 0 && check_range) {
338     at::_assert_async(index.max() < dim_size);
339     at::_assert_async(index.min() >= -dim_size);
340   }
341   return index.remainder(dim_size);
342 }
343 
344 static std::vector<int64_t> computeLinearStride(const Tensor & tensor) {
345   // computes the stride as if tensor were contiguous
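  // For example, sizes [2, 3, 4] yield strides [12, 4, 1].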
346   auto sizes = tensor.sizes();
347   std::vector<int64_t> stride(tensor.dim());
348   if (stride.empty()) {
349     return stride;
350   }
351   stride[tensor.dim() - 1] = 1;
352   std::partial_sum(sizes.rbegin(), sizes.rend() - 1, stride.rbegin() + 1, std::multiplies<int64_t>());
353   return stride;
354 }
355 
356 static std::tuple<Tensor, int64_t, int64_t, int64_t>
357 computeLinearIndex(const Tensor & src, TensorList indices, bool check_range) {
358   auto strides = computeLinearStride(src);
359   const auto& device = src.options().device();
360 
361   // Compute the linear index by multiplying the indexing tensors by the
362   // stride and summing them. All the indexing tensors have the same shape at
363   // this point. We also compute the number of dimensions before and after that
364   // are not being indexed.
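  // For example (hypothetical shapes): for a contiguous src of shape [5, 6, 7]
  // indexed as src[:, i, :], only indices[1] is defined, so
  // linearIndex == i * strides[1] == i * 7 (after wrapping negative entries),
  // nElemBefore == 5, strideBefore == src.stride(0) == 42 and nElemAfter == 7.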
365   Tensor linearIndex;
366   int64_t nElemBefore = 1, nElemAfter = 1, strideBefore =0;
367   for (const auto i: c10::irange(src.dim())) {
368     if (indices[i].defined()) {
369       // Cast the index to long and move it to src's device.
370       // This allows us to support e.g. indexing a CUDA tensor with a CPU index tensor.
371       Tensor index = (wrapIndexOnce(indices[i], i, src.size(i), check_range) * strides[i]).to(device);
372       if (linearIndex.defined()) {
373         linearIndex += index;
374       } else {
375         linearIndex = index;
376         if (i>0) {
377            strideBefore = src.stride(i-1); // stride after undefined dimensions
378         }
379       }
380     } else if (linearIndex.defined()) {
381       nElemAfter *= src.size(i);
382     } else {
383       nElemBefore *= src.size(i);
384     }
385   }
386 
387   return std::make_tuple(std::move(linearIndex), nElemBefore, strideBefore, nElemAfter);
388 }
389 
390 
391 static std::tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t>> makeLinearIndex(Tensor self, IOptTensorListRef orig, bool check_range) {
392   checkIndexTensorTypes(orig, /*allow_int*/true);
393   // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
394   auto indices = expandTensors(self, orig);
395   for (auto & i : indices) {
396     if (i.defined() && i.dtype() == at::kInt) {
397       i = i.to(at::kLong);
398     }
399   }
400   // next broadcast all index tensors together
401   indices = expand_outplace(indices);
402   // add missing null Tensors so that it matches self.dim()
403   while (indices.size() < (size_t)self.dim()) {
404     indices.emplace_back();
405   }
406   // if the non-null indices are not all adjacent, transpose self and indices
407   // together so that they're adjacent at the front
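  // For example (hypothetical pattern): indices of the form (i, None, j) are not
  // a contiguous subspace, so self is permuted to dimension order (0, 2, 1) and
  // the defined indices are moved to the front; inversePerm undoes this later.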
408   std::vector<int64_t> inversePerm;
409   if (!hasContiguousSubspace(indices)) {
410     std::tie(self, indices, inversePerm) = transposeToFrontAndInvPerm(self, indices);
411   }
412   auto [linearIndex, nElemBefore, strideBefore, nElemAfter] = computeLinearIndex(self, indices, check_range);
413   return std::make_tuple(linearIndex, self, nElemBefore, strideBefore, nElemAfter, inversePerm);
414 }
415 
416 
417 void index_put_with_sort_kernel_thrust_helper(Tensor &linearIndex, Tensor &orig_indices, Tensor &sorted_indices, int64_t num_indices);
418 
419 namespace {
420 
421 int64_t largestIndex(const Tensor &self) {
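  // Returns the largest linear offset into self's storage, i.e. the sum over
  // all dimensions of (size - 1) * stride.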
422   int64_t result = 0;
423   for (const auto i: c10::irange(self.dim())) {
424     result += (self.sizes()[i] - 1) * self.strides()[i];
425   }
426   return result;
427 }
428 
429 void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Tensor>>& indices, const Tensor & value, bool accumulate, bool unsafe) {
430   TORCH_CHECK(!indices.empty() || is_expandable_to(value.sizes(), self.sizes()), "shape mismatch: value tensor of shape ", value.sizes(),
431              " cannot be broadcast to indexing result of shape ", self.sizes());
432   if (indices.size() > (size_t)self.dim()) {
433     TORCH_CHECK_INDEX(false, "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
434   }
435   bool self_contiguous = self.is_contiguous();
436   auto self_ = self_contiguous ? self : self.contiguous();
437   Tensor linearIndex, src, expandedValue = value;
438   int64_t nElemBefore, strideBefore, sliceSize;
439   std::vector<int64_t> inversePerm;
440   std::tie(linearIndex, src, nElemBefore, strideBefore, sliceSize, inversePerm) = makeLinearIndex(self_, indices, !unsafe);
441   int64_t num_indices = linearIndex.numel();
442 
443   if (expandedValue.numel() < num_indices * nElemBefore * sliceSize) {
444     auto expanded_size = at::DimVector(expandedValue.sizes());
445     auto size1 = expandedValue.sizes();
446     auto size2 = linearIndex.sizes();
447     if (are_expandable(size1, size2)) {
448       expanded_size = infer_size_dimvector(size1, size2);
449     }
450     if (nElemBefore > 1) {
451       expanded_size.insert(expanded_size.begin(), nElemBefore);
452     }
453     if (sliceSize > 1) {
454       expanded_size.insert(expanded_size.end(), sliceSize);
455     }
456     expandedValue = expandedValue.expand(expanded_size);
457   }
458   expandedValue = expandedValue.contiguous();
459 
460   if (num_indices > 0 && sliceSize > 0) {
461       const bool permuted = !src.is_contiguous();
462       auto src_ = permuted ? src.contiguous() : src;
463       linearIndex = linearIndex.reshape(-1);
464       auto sorted_indices = at::empty_like(linearIndex, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
465       auto orig_indices = at::empty_like(linearIndex, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
466       const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
467 
468       linearIndex.divide_(sliceSize, "trunc");
469 
470       // cub on CUDA <= 11.2 has a bug where, for small sizes,
471       // cub's sort can be much slower than thrust's merge sort;
472       // this bug is fixed in CUDA 11.3.
473 #if (defined(CUDA_VERSION) && CUDA_VERSION < 11030) || defined(USE_ROCM)
474       if (num_indices < 50000) {
475         index_put_with_sort_kernel_thrust_helper(linearIndex, orig_indices, sorted_indices, num_indices);
476       } else
477 #endif
478       {
479       // Sort the linear indices into sorted_indices, carrying the original positions along in orig_indices
480       auto range = at::arange(num_indices, linearIndex.options());
481       // linearIndex cannot be negative, and we take advantage of this
482       // fact to sort on fewer bits for better performance.
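      // For example, if the largest slice index that can appear is 1000, only
      // ceil(log2(1001)) == 10 bits take part in the radix sort passes.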
483       int64_t nbits = cuda::cub::get_num_bits(largestIndex(self_) / sliceSize);
484       cuda::cub::radix_sort_pairs(
485         linearIndex.const_data_ptr<int64_t>(), sorted_indices.mutable_data_ptr<int64_t>(),
486         range.const_data_ptr<int64_t>(), orig_indices.mutable_data_ptr<int64_t>(),
487         num_indices, false, 0, nbits);
488       }
489 
490       TORCH_INTERNAL_ASSERT(
491           linearIndex.numel()*sliceSize*nElemBefore == expandedValue.numel(),
492           "number of flattened indices did not match number of elements in the value tensor: ",
493           linearIndex.numel()*sliceSize*nElemBefore, " vs ", expandedValue.numel());
494       const int UNROLL = 4;
495       const int indices_per_block = 4;
496       const int warp_size = at::cuda::warp_size();
497       dim3 grid(ceil_div(num_indices, (int64_t) indices_per_block),
498            std::min<int>(at::cuda::getCurrentDeviceProperties()->maxGridSize[1], ceil_div(sliceSize, (int64_t) (warp_size*UNROLL))),
499            std::min(std::max<int>(1,nElemBefore), at::cuda::getCurrentDeviceProperties()->maxGridSize[2]));
500       dim3 block(warp_size, indices_per_block);
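      // Launch shape: grid.x walks the sorted indices in groups of
      // indices_per_block, grid.y tiles each slice in chunks of
      // warp_size * UNROLL features, and grid.z covers the nElemBefore outer
      // slices; each block is warp_size x indices_per_block threads, so one
      // warp (one threadIdx.y) handles one sorted index.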
501 
502 
503       if (sliceSize == 1) {
504         // This implementation is faster when there are many duplicate indices,
505         // but it could overflow if FP16 / BF16 is used.
506         AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBool, kBFloat16,
507         expandedValue.scalar_type(), "indexing_backward_kernel_stride_1", [&] {
508           indexing_backward_kernel_stride_1<scalar_t><<<grid, block, 0, stream>>>(
509             sorted_indices.const_data_ptr<int64_t>(),
510             orig_indices.const_data_ptr<int64_t>(),
511             expandedValue.const_data_ptr<scalar_t>(),
512             src_.mutable_data_ptr<scalar_t>(),
513             num_indices,
514             sliceSize,
515             strideBefore,
516             nElemBefore,
517             accumulate);
518           C10_CUDA_KERNEL_LAUNCH_CHECK();
519         });
520       } else {
521         if (sliceSize <= warp_size) {
522           AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBool, kBFloat16,
523           expandedValue.scalar_type(), "indexing_backward_kernel_small_stride", [&] {
524             indexing_backward_kernel_small_stride<scalar_t><<<grid, block, 0, stream>>>(
525               sorted_indices.const_data_ptr<int64_t>(),
526               orig_indices.const_data_ptr<int64_t>(),
527               expandedValue.const_data_ptr<scalar_t>(),
528               src_.mutable_data_ptr<scalar_t>(),
529               num_indices,
530               sliceSize,
531               strideBefore,
532               nElemBefore,
533               accumulate);
534             C10_CUDA_KERNEL_LAUNCH_CHECK();
535             });
536         } else {
537             AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBool, kBFloat16,
538             expandedValue.scalar_type(), "indexing_backward", [&] {
539               indexing_backward_kernel<scalar_t, UNROLL><<<grid, block, 0, stream>>>(
540                 sorted_indices.const_data_ptr<int64_t>(),
541                 orig_indices.const_data_ptr<int64_t>(),
542                 expandedValue.const_data_ptr<scalar_t>(),
543                 src_.mutable_data_ptr<scalar_t>(),
544                 num_indices,
545                 sliceSize,
546                 strideBefore,
547                 nElemBefore,
548                 accumulate);
549               C10_CUDA_KERNEL_LAUNCH_CHECK();
550             });
551           }
552         }
553 
554       if (permuted) {
555         self.copy_(src_.permute(inversePerm));
556       } else if (!self_contiguous) {
557         self.copy_(self_);
558       }
559   }
560 }
561 
562 REGISTER_CUDA_DISPATCH(index_put_with_sort_stub, &index_put_with_sort_kernel);
563 
564 void index_put_with_sort_quantized(Tensor & self, const c10::List<std::optional<Tensor>>& indices, const Tensor & value, double scale, int zero_point, bool unsafe) {
565   if (indices.size() > (size_t)self.dim()) {
566     TORCH_CHECK_INDEX(false, "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
567   }
568   bool self_contiguous = self.is_contiguous();
569   auto self_ = self_contiguous ? self : self.contiguous();
570   Tensor linearIndex, src, expandedValue = value;
571   int64_t nElemBefore, strideBefore, sliceSize;
572   std::vector<int64_t> inversePerm;
573   std::tie(linearIndex, src, nElemBefore, strideBefore, sliceSize, inversePerm) = makeLinearIndex(self_, indices, !unsafe);
574   int64_t num_indices = linearIndex.numel();
575 
576   if (expandedValue.numel() < num_indices * nElemBefore * sliceSize) {
577     auto expanded_size = at::DimVector(expandedValue.sizes());
578     auto size1 = expandedValue.sizes();
579     auto size2 = linearIndex.sizes();
580     if (are_expandable(size1, size2)) {
581       expanded_size = infer_size_dimvector(size1, size2);
582     }
583     if (nElemBefore > 1) {
584       expanded_size.insert(expanded_size.begin(), nElemBefore);
585     }
586     expandedValue = expandedValue.expand(expanded_size);
587   }
588   expandedValue = expandedValue.contiguous();
589 
590   if (num_indices > 0 && sliceSize > 0) {
591       const bool permuted = !src.is_contiguous();
592       auto src_ = permuted ? src.contiguous() : src;
593       linearIndex = linearIndex.reshape(-1);
594       auto sorted_indices = at::empty_like(linearIndex, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
595       auto orig_indices = at::empty_like(linearIndex, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
596       const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
597 
598       linearIndex.divide_(sliceSize, "trunc");
599 
600       // cub on CUDA <= 11.2 has a bug where, for small sizes,
601       // cub's sort can be much slower than thrust's merge sort;
602       // this bug is fixed in CUDA 11.3.
603 #if (defined(CUDA_VERSION) && CUDA_VERSION < 11030) || defined(USE_ROCM)
604       if (num_indices < 50000) {
605         index_put_with_sort_kernel_thrust_helper(linearIndex, orig_indices, sorted_indices, num_indices);
606       } else
607 #endif
608       {
609       // Sort the linear indices into sorted_indices, carrying the original positions along in orig_indices
610       auto range = at::arange(num_indices, linearIndex.options());
611       // linearIndex cannot be negative, and we take advantage of this
612       // fact to sort on fewer bits for better performance.
613       int64_t nbits = cuda::cub::get_num_bits(largestIndex(self_) / sliceSize);
614       cuda::cub::radix_sort_pairs(
615         linearIndex.const_data_ptr<int64_t>(), sorted_indices.mutable_data_ptr<int64_t>(),
616         range.const_data_ptr<int64_t>(), orig_indices.mutable_data_ptr<int64_t>(),
617         num_indices, false, 0, nbits);
618       }
619 
620       TORCH_INTERNAL_ASSERT(
621           linearIndex.numel()*sliceSize*nElemBefore == expandedValue.numel(),
622           "number of flattened indices did not match number of elements in the value tensor: ",
623           linearIndex.numel()*sliceSize*nElemBefore, " vs ", expandedValue.numel());
624       const int UNROLL = 4;
625       const int indices_per_block = 4;
626       const int warp_size = at::cuda::warp_size();
627       dim3 grid(ceil_div(num_indices, (int64_t) indices_per_block),
628            std::min<int>(at::cuda::getCurrentDeviceProperties()->maxGridSize[1], ceil_div(sliceSize, (int64_t) (warp_size*UNROLL))),
629            std::min(std::max<int>(1,nElemBefore), at::cuda::getCurrentDeviceProperties()->maxGridSize[2]));
630       dim3 block(warp_size, indices_per_block);
631 
632       AT_DISPATCH_QINT_TYPES(
633         src.scalar_type(), "indexing_backward_quantized", [&] {
634         constexpr int64_t qmin = std::numeric_limits<typename scalar_t::underlying>::min();
635         constexpr int64_t qmax = std::numeric_limits<typename scalar_t::underlying>::max();
636         float inv_scale = 1.0f / static_cast<float>(scale);
637 
638         indexing_backward_kernel_quantized<scalar_t, UNROLL><<<grid, block, 0, stream>>>(
639           sorted_indices.const_data_ptr<int64_t>(),
640           orig_indices.const_data_ptr<int64_t>(),
641           expandedValue.const_data_ptr<float>(),
642           src_.mutable_data_ptr<scalar_t>(),
643           num_indices,
644           sliceSize,
645           strideBefore,
646           nElemBefore,
647           inv_scale,
648           zero_point,
649           qmin,
650           qmax);
651         C10_CUDA_KERNEL_LAUNCH_CHECK();
652       });
653 
654       if (permuted) {
655         self.copy_(src_.permute(inversePerm));
656       } else if (!self_contiguous) {
657         self.copy_(self_);
658       }
659   }
660 }
661 
662 REGISTER_CUDA_DISPATCH(index_put_with_sort_quantized_stub, &index_put_with_sort_quantized);
663 } //anonymous
664 
665 
666 // Check tensor dimensions for index operations, and return the slice size.
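// For example (hypothetical shapes): for dst of shape [4, 5, 6] and dim == 1,
// the slice size is 4 * 6 == 24; src must then also have slices of 24 elements,
// and src.size(1) must equal index.numel().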
667 static size_t getSliceSize(const Tensor & dst,
668                               int dim,
669                               const Tensor & index,
670                               const Tensor & src)
671 {
672   const auto dstDims = dst.dim();
673   const auto srcDims = src.dim();
674 
675   TORCH_CHECK(index.dim() <= 1, "Index must be vector or scalar");
676 
677   size_t dstSliceSize = 1;
678   TORCH_CHECK(dim >= 0 && dim < dstDims, "Indexing dim ", dim, " is out of bounds");
679   for (const auto d: c10::irange(dstDims)) {
680     if (d != dim) {
681       dstSliceSize *= dst.size(d);
682     }
683   }
684 
685   TORCH_CHECK(dim < srcDims, "Indexing dim ", dim, " is out of bounds");
686   TORCH_CHECK(index.numel() == src.size(dim),
687              "length of src.size[dim] is not equal to length of indices");
688 
689   size_t srcSliceSize = 1;
690   bool mismatch = false;
691 
692   if (dstDims != srcDims) mismatch = true;
693 
694   for (const auto d: c10::irange(srcDims)) {
695     if (d != dim) {
696       srcSliceSize *= src.size(d);
697       if (!mismatch && dst.size(d) != src.size(d)) mismatch = true;
698     }
699   }
700 
701   TORCH_CHECK(dstSliceSize == srcSliceSize,
702              "Source/destination tensors have different slice sizes (",
703              dstSliceSize, " vs ", srcSliceSize, ")");
704 
705   if (mismatch) {
706     TORCH_WARN_ONCE(
707         "Warning: source/destination slices have same size but different "
708         "shape for an index operation.  This behavior is deprecated.\n");
709   }
710 
711   return dstSliceSize;
712 }
713 
714 // We prefer this kernel to avoid reloading index points if the number
715 // of indices is small.
716 // This kernel in fact works for all choices of problem size, but if
717 // the number of indices chosen is large, then the
718 // indexFuncLargeIndex kernel is a better choice to increase
719 // parallelism.
720 template <typename T, typename IndicesType, typename IndexType, int DstDim, int SrcDim, int IdxDim,
721           typename func_t>
722 __global__ void indexFuncSmallIndex(cuda::detail::TensorInfo<T, IndexType> dst,
723                                     cuda::detail::TensorInfo<const T, IndexType> src,
724                                     cuda::detail::TensorInfo<const IndicesType, IndexType> indices,
725                                     int dstAddDim,
726                                     int srcAddDim,
727                                     IndexType innerSize,
728                                     int64_t dstAddDimSize,
729                                     int64_t dstNumel,
730                                     const func_t& op,
731                                     T alpha) {
732   // In order to avoid reloading the index that we are copying, load
733   // it once to handle all of the points that are being selected, so
734   // it can be reused as much as possible. This kernel is chosen when
735   // this is a good choice (small number of chosen indices), since
736   // re-accessing indices in addition to src elements can be slow.
737   for (IndexType srcIndex = 0; srcIndex < indices.sizes[0]; ++srcIndex) {
738     // Look up the destination index for this source slice.
739     IndexType dstIndex =
740         indices.data[cuda::detail::IndexToOffset<const IndicesType, IndexType, IdxDim>::get(srcIndex, indices)];
741     CUDA_KERNEL_ASSERT(dstIndex < dstAddDimSize);
742 
743     // We stride over the output ignoring the indexed dimension
744     // (innerSize), whose offset calculation is handled differently
745     for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
746          linearIndex < innerSize;
747          linearIndex += gridDim.x * blockDim.x) {
748       IndexType dstOffset =
749           cuda::detail::IndexToOffset<T, IndexType, DstDim>::get(linearIndex, dst);
750       dstOffset += dstIndex * dst.strides[dstAddDim];
751 
752       IndexType srcOffset =
753           cuda::detail::IndexToOffset<const T, IndexType, SrcDim>::get(linearIndex, src);
754       srcOffset += srcIndex * src.strides[srcAddDim];
755 
756       T val = src.data[srcOffset] * alpha;
757       op(dst.data, dstOffset, dstNumel, &val);
758     }
759 
760   }
761 }
762 
763 // We prefer this kernel to balance parallelism across index points,
764 // if there are a large number of indices.
765 // This kernel in fact works for all choices of problem size, but if
766 // the number of indices chosen is small, then the
767 // indexFuncSmallIndex kernel is a better choice to reduce memory
768 // accesses.
769 template <typename T, typename IndicesType, typename IndexType, int DstDim, int SrcDim, int IdxDim,
770           bool IndexIsMajor, typename func_t>
771 __global__ void indexFuncLargeIndex(cuda::detail::TensorInfo<T, IndexType> dst,
772                                     cuda::detail::TensorInfo<const T, IndexType> src,
773                                     cuda::detail::TensorInfo<const IndicesType, IndexType> indices,
774                                     int dstAddDim,
775                                     int srcAddDim,
776                                     IndexType totalSize,
777                                     IndexType innerSize,
778                                     int64_t dstAddDimSize,
779                                     int64_t dstNumel,
780                                     const func_t& op,
781                                     T alpha) {
782   // We stride over the output including the indexed dimension
783   // (totalSize), and calculate the destination index point based on that
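  // When IndexIsMajor, innerSize is the slice size, so consecutive linear ids
  // first enumerate the elements of one index slice before moving to the next
  // index; otherwise innerSize is the number of indices and consecutive ids
  // first sweep across indices (matching the host's
  // (IDX_IS_MAJOR) ? sliceSize : numIndex argument).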
784   for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
785        linearIndex < totalSize;
786        linearIndex += gridDim.x * blockDim.x) {
787     IndexType srcIndex, elementInSlice;
788     if (IndexIsMajor) {
789       srcIndex = linearIndex / innerSize;
790       elementInSlice = linearIndex % innerSize;
791     }
792     else {
793       elementInSlice = linearIndex / innerSize;
794       srcIndex = linearIndex % innerSize;
795     }
796 
797     // Look up the destination index for this source slice.
798     IndexType dstIndex =
799         indices.data[cuda::detail::IndexToOffset<const IndicesType, IndexType, IdxDim>::get(srcIndex, indices)];
800     CUDA_KERNEL_ASSERT(dstIndex < dstAddDimSize);
801 
802     IndexType dstOffset =
803       cuda::detail::IndexToOffset<T, IndexType, DstDim>::get(elementInSlice, dst);
804     dstOffset += dstIndex * dst.strides[dstAddDim];
805 
806     IndexType srcOffset =
807       cuda::detail::IndexToOffset<const T, IndexType, SrcDim>::get(elementInSlice, src);
808     srcOffset += srcIndex * src.strides[srcAddDim];
809 
810     T val = src.data[srcOffset] * alpha;
811     op(dst.data, dstOffset, dstNumel, &val);
812   }
813 }
814 
815 // Compare the stride between adjacent slices (sliceStride) with strides in the
816 // other dimensions (i.e., strides *inside* each slice).
817 //
818 // - Returns true if some dimension inside the slice has lower stride than
819 //   sliceStride.  The simplest example is a 2-D contiguous tensor with sliceDim
820 //   == 0 (that is, each slice is a row).
821 //
822 //   In this case, we choose the CUDA kernel that processes the data in
823 //   "index-major order".  For example, if thread count equals slice size, then
824 //   all threads process slice #0 in lockstep, and then slice #1, and so on.
825 //
826 // - Otherwise (i.e., sliceStride has the lowest value), this function returns
827 //   false.  The simplest example is a 2-D contiguous tensor with sliceDim == 1
828 //   (each slice is a column).
829 //
830 //   In this case, we choose the CUDA kernel that processes the data in
831 //   "elementInSlice-major order".  For example, each thread can process element
832 //   #0 of every slice, and then element #1 of every slice, and so on.
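// For example, a contiguous 2-D tensor of shape [N, M] (with M > 1) and
// sliceDim == 0 has sliceStride == M while dimension 1 has stride 1 < M, so
// this returns true (index-major); with sliceDim == 1 the slice stride is
// already the smallest, so it returns false.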
833 template <typename scalar_t>
834 bool indexShouldBeMajor(cuda::detail::TensorInfo<scalar_t, unsigned int> &info,
835                                     int sliceDim)
836 {
837   // The stride between adjacent slices (e.g., between element #0 of slice #100
838   // and element #0 of slice #101).
839   unsigned int sliceStride = info.strides[sliceDim];
840 
841   for (const auto i: c10::irange(info.dims)) {
842     if (i != sliceDim && info.sizes[i] > 1 && info.strides[i] < sliceStride) {
843       return true;
844     }
845   }
846 
847   return false;
848 }
849 
850 void index_add_cuda_impl(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& source, const Scalar& alpha, const Tensor& result) {
851   if (!result.is_same(self)) {
852     result.copy_(self);
853   }
854 
855   // Scalars are treated as 1-d tensor
856   const Tensor self_ = (result.dim() == 0) ? result.view(1) : result;
857   const Tensor source_ = (source.dim() == 0) ? source.view(1) : source;
858 
859   TORCH_CHECK(result.dim() <= MAX_TENSORINFO_DIMS, "tensor has too many (>", MAX_TENSORINFO_DIMS, ") dims");
860   TORCH_CHECK(source.dim() <= MAX_TENSORINFO_DIMS, "tensor has too many (>", MAX_TENSORINFO_DIMS, ") dims" );
861   TORCH_CHECK(index.dim() <= MAX_TENSORINFO_DIMS, "tensor has too many (>", MAX_TENSORINFO_DIMS, ") dims");
862 
863   if (globalContext().deterministicAlgorithms()){
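    // Deterministic path: build an index list of `dim` null indices followed by
    // `index` (the equivalent of indexing self along `dim`) and use accumulating
    // index_put_, which goes through the sort-based kernel above and adds
    // duplicates in a fixed order, making the result deterministic.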
864     torch::List<std::optional<Tensor>> indices;
865     indices.reserve(dim + 1);
866     for (const auto i: c10::irange(dim)) {
867       indices.emplace_back();
868     }
869     indices.emplace_back(index.to(at::kLong));
870     result.index_put_(indices, source * alpha, true);
871     return;
872   }
873 
874   // The `source` is partitioned into two parts:
875   // - the size of each slice we are indexing, which is the
876   //   total size of the tensor ignoring dimension `dim`;
877   // - the number of indices we are choosing, which is the total
878   //   size of the tensor `index`.
879   const uint64_t sliceSize = getSliceSize(self_, dim, index, source_);
880   const uint64_t sourceTotalSize = source.numel();
881   const uint64_t selfAddDimSize = self_.size(dim);
882   const uint64_t numIndex = index.numel();
883   const uint64_t selfNumel = self_.numel();
884 
885   if (sliceSize == 0) {
886     return;
887   }
888   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
889   const bool indContig = index.is_contiguous();
890 
891   const int mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
892 
893 #define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM)     \
894   indexFuncSmallIndex<TENSOR_TYPE, INDICES_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM>   \
895     <<<smallIndexGrid, smallIndexBlock, 0, stream>>>(                                   \
896       selfInfo, sourceInfo, indexInfo,                                                  \
897       selfAddDim, sourceAddDim, sliceSize, selfAddDimSize,                              \
898       selfNumel, reduce_add, alpha_value);                                              \
899   C10_CUDA_KERNEL_LAUNCH_CHECK();
900 
901 #define LARGE_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE,                        \
902                     SELF_DIM, SOURCE_DIM, IDX_DIM, IDX_IS_MAJOR)            \
903   indexFuncLargeIndex<TENSOR_TYPE, INDICES_TYPE, TYPE,                      \
904                       SELF_DIM, SOURCE_DIM, IDX_DIM, IDX_IS_MAJOR>          \
905     <<<largeIndexGrid, largeIndexBlock, 0, stream>>>(                       \
906       selfInfo, sourceInfo, indexInfo,                                      \
907       selfAddDim, sourceAddDim, sourceTotalSize,                            \
908       (IDX_IS_MAJOR) ? sliceSize : numIndex,                                \
909       selfAddDimSize, selfNumel, reduce_add, alpha_value);                  \
910   C10_CUDA_KERNEL_LAUNCH_CHECK();
911 
912   const dim3 smallIndexGrid(std::min(ceil_div(sliceSize, (uint64_t)128), (uint64_t)(mpc * 8)));
913   const dim3 smallIndexBlock(std::min(sliceSize, (uint64_t)128));
914 
915   const dim3 largeIndexGrid(std::min(ceil_div(sourceTotalSize, (uint64_t)128), (uint64_t)(mpc * 8)));
916   const dim3 largeIndexBlock(std::min(sourceTotalSize, (uint64_t)128));
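  // Launch heuristics: blocks of up to 128 threads, with the grid capped at
  // 8 * multiProcessorCount blocks; the small-index kernel parallelizes over a
  // single slice (sliceSize work items) while the large-index kernel
  // parallelizes over the whole source (sourceTotalSize work items).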
917 
918   if (cuda::detail::canUse32BitIndexMath(result) &&
919       cuda::detail::canUse32BitIndexMath(source) &&
920       cuda::detail::canUse32BitIndexMath(index)) {
921     AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::ComplexHalf, result.scalar_type(), "index_add", [&] {
922       cuda::detail::TensorInfo<scalar_t, unsigned int> selfInfo =
923           cuda::detail::getTensorInfo<scalar_t, unsigned int>(self_);
924       const int selfAddDim = selfInfo.collapseDims(dim);
925       selfInfo.reduceDim(selfAddDim);
926       const auto alpha_value = alpha.to<scalar_t>();
927       AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cuda_", [&] () {
928         auto sourceInfo =
929           cuda::detail::getTensorInfo<const scalar_t, unsigned int>(source_);
930         const int sourceAddDim = sourceInfo.collapseDims(dim);
931         sourceInfo.reduceDim(sourceAddDim);
932 
933         auto indexInfo =
934         cuda::detail::getTensorInfo<const index_t, unsigned int>(index);
935         indexInfo.collapseDims();
936 
937         // Heuristic: with at most 16 indices it is cheaper to have each
938         // thread iterate over all indices (indexFuncSmallIndex) than to
939         // parallelize across them.
939         if (numIndex <= 16) {
940           if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) {
941             SMALL_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2);
942           } else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) {
943             SMALL_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2);
944           } else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) {
945             SMALL_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2);
946           } else {
947             SMALL_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1);
948           }
949         } else {
950           const bool indexIsMajor = indexShouldBeMajor(selfInfo, selfAddDim);
951 
952           if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) {
953             LARGE_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2, true);
954           } else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) {
955             if (indexIsMajor) {
956               LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, true);
957             } else {
958               LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, false);
959             }
960           } else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) {
961             if (indexIsMajor) {
962               LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, true);
963             } else {
964               LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, false);
965             }
966           } else {
967             LARGE_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1, true);
968           }
969         }
970       });
971     });
972   } else {
973     AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "index_add", [&] {
974       cuda::detail::TensorInfo<scalar_t, uint64_t> selfInfo =
975         cuda::detail::getTensorInfo<scalar_t, uint64_t>(self_);
976       const int selfAddDim = selfInfo.collapseDims(dim);
977       selfInfo.reduceDim(selfAddDim);
978       const auto alpha_value = alpha.to<scalar_t>();
979 
980       cuda::detail::TensorInfo<const scalar_t, uint64_t> sourceInfo =
981         cuda::detail::getTensorInfo<const scalar_t, uint64_t>(source_);
982       const int sourceAddDim = sourceInfo.collapseDims(dim);
983       sourceInfo.reduceDim(sourceAddDim);
984 
985       AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cuda_", [&] () {
986         cuda::detail::TensorInfo<const index_t, uint64_t> indexInfo =
987           cuda::detail::getTensorInfo<const index_t, uint64_t>(index);
988         indexInfo.collapseDims();
989 
990         LARGE_INDEX(scalar_t, index_t, uint64_t, -1, -1, -1, true);
991       });
992     });
993   }
994 
995 #undef SMALL_INDEX
996 #undef LARGE_INDEX
997 }
998 
999 template <typename func_t>
1000 void index_reduce_func_cuda_impl(
1001   const Tensor& self,
1002   int64_t dim,
1003   const Tensor& index,
1004   const Tensor& source,
1005   bool include_self,
1006   const ReductionType& reduce,
1007   const func_t& reduce_func,
1008   const Tensor& result) {
1009   globalContext().alertNotDeterministic("index_reduce_cuda");
1010 
1011   if (!result.is_same(self)) result.copy_(self);
1012 
1013   // Scalars are treated as 1-d tensor
1014   Tensor self_ = (result.dim() == 0) ? result.view(1) : result;
1015   Tensor source_ = (source.dim() == 0) ? source.view(1) : source;
1016 
1017   TORCH_CHECK(result.dim() <= MAX_TENSORINFO_DIMS, "tensor has too many (>", MAX_TENSORINFO_DIMS, ") dims");
1018   TORCH_CHECK(source.dim() <= MAX_TENSORINFO_DIMS, "tensor has too many (>", MAX_TENSORINFO_DIMS, ") dims" );
1019   TORCH_CHECK(index.dim() <= MAX_TENSORINFO_DIMS, "tensor has too many (>", MAX_TENSORINFO_DIMS, ") dims");
1020 
1021   if (!include_self) {
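    // When the original values of self are excluded, the destination slices
    // selected by `index` are first filled with the reduction's identity
    // element: 1 for prod, -inf (or lowest) for amax, +inf (or max) for amin,
    // and 0 otherwise.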
1022     AT_DISPATCH_ALL_TYPES_AND2(
1023       at::ScalarType::Half, at::ScalarType::BFloat16,
1024       self.scalar_type(), "index_reduce_func_cuda_exclude_input_init", [&] {
1025       scalar_t init_val;
1026       switch (reduce) {
1027         case ReductionType::PROD:
1028           init_val = (scalar_t)1;
1029           break;
1030         case ReductionType::MAX:
1031           init_val = std::numeric_limits<scalar_t>::has_infinity ? -std::numeric_limits<scalar_t>::infinity()
1032                      : std::numeric_limits<scalar_t>::lowest();
1033           break;
1034         case ReductionType::MIN:
1035           init_val = std::numeric_limits<scalar_t>::has_infinity ? std::numeric_limits<scalar_t>::infinity()
1036                      : std::numeric_limits<scalar_t>::max();
1037           break;
1038         default:
1039           init_val = (scalar_t)0;
1040           break;
1041       }
1042       // index_fill_ requires index to be a LongTensor
1043       self_.index_fill_(dim, index.to(at::ScalarType::Long), init_val);
1044     });
1045   }
1046 
1047   // The `source` is partitioned into two parts:
1048   // - the size of each slice we are indexing, which is the
1049   //   total size of the tensor ignoring dimension `dim`;
1050   // - the number of indices we are choosing, which is the total
1051   //   size of the tensor `index`.
1052   uint64_t sliceSize = getSliceSize(self_, dim, index, source_);
1053   uint64_t sourceTotalSize = source.numel();
1054   uint64_t selfReduceDimSize = self_.size(dim);
1055   uint64_t numIndex = index.numel();
1056   uint64_t selfNumel = self_.numel();
1057 
1058   if (sliceSize == 0) {
1059     return;
1060   }
1061   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
1062   bool indContig = index.is_contiguous();
1063 
1064   int mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
1065 
1066 #define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM)                  \
1067   indexFuncSmallIndex<TENSOR_TYPE, INDICES_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM>                \
1068     <<<smallIndexGrid, smallIndexBlock, 0, stream>>>(                                                \
1069       selfInfo, sourceInfo, indexInfo,                                                               \
1070       selfReduceDim, sourceReduceDim, sliceSize, selfReduceDimSize,                                  \
1071       selfNumel, reduce_func, alpha_value);                                                          \
1072   C10_CUDA_KERNEL_LAUNCH_CHECK();
1073 
1074 #define LARGE_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE,                                     \
1075                     SELF_DIM, SOURCE_DIM, IDX_DIM, IDX_IS_MAJOR)                         \
1076   indexFuncLargeIndex<TENSOR_TYPE, INDICES_TYPE, TYPE,                                   \
1077                      SELF_DIM, SOURCE_DIM, IDX_DIM, IDX_IS_MAJOR>                        \
1078     <<<largeIndexGrid, largeIndexBlock, 0, stream>>>(                                    \
1079       selfInfo, sourceInfo, indexInfo,                                                   \
1080       selfReduceDim, sourceReduceDim, sourceTotalSize,                                   \
1081       (IDX_IS_MAJOR) ? sliceSize : numIndex,                                             \
1082       selfReduceDimSize, selfNumel, reduce_func, alpha_value);                           \
1083   C10_CUDA_KERNEL_LAUNCH_CHECK();
1084 
1085   dim3 smallIndexGrid(std::min(ceil_div(sliceSize, (uint64_t)128), (uint64_t)(mpc * 8)));
1086   dim3 smallIndexBlock(std::min(sliceSize, (uint64_t)128));
1087 
1088   dim3 largeIndexGrid(std::min(ceil_div(sourceTotalSize, (uint64_t)128), (uint64_t)(mpc * 8)));
1089   dim3 largeIndexBlock(std::min(sourceTotalSize, (uint64_t)128));
1090 
1091   if (cuda::detail::canUse32BitIndexMath(result) &&
1092       cuda::detail::canUse32BitIndexMath(source) &&
1093       cuda::detail::canUse32BitIndexMath(index)) {
1094     AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, result.scalar_type(), "index_reduce", [&] {
1095       cuda::detail::TensorInfo<scalar_t, unsigned int> selfInfo =
1096           cuda::detail::getTensorInfo<scalar_t, unsigned int>(self_);
1097       int selfReduceDim = selfInfo.collapseDims(dim);
1098       selfInfo.reduceDim(selfReduceDim);
1099       auto alpha_value = (scalar_t) 1;
1100       AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_reduce_cuda", [&] () {
1101         auto sourceInfo =
1102           cuda::detail::getTensorInfo<const scalar_t, unsigned int>(source_);
1103         int sourceReduceDim = sourceInfo.collapseDims(dim);
1104         sourceInfo.reduceDim(sourceReduceDim);
1105 
1106         auto indexInfo =
1107         cuda::detail::getTensorInfo<const index_t, unsigned int>(index);
1108         indexInfo.collapseDims();
1109 
1110         // Heuristic: with at most 16 indices it is cheaper to have each
1111         // thread iterate over all indices (indexFuncSmallIndex) than to
1112         // parallelize across them.
1112         if (numIndex <= 16) {
1113           if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) {
1114             SMALL_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2);
1115           } else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) {
1116             SMALL_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2);
1117           } else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) {
1118             SMALL_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2);
1119           } else {
1120             SMALL_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1);
1121           }
1122         } else {
1123           bool indexIsMajor = indexShouldBeMajor(selfInfo, selfReduceDim);
1124 
1125           if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) {
1126             LARGE_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2, true);
1127           } else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) {
1128             if (indexIsMajor) {
1129               LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, true);
1130             } else {
1131               LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, false);
1132             }
1133           } else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) {
1134             if (indexIsMajor) {
1135               LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, true);
1136             } else {
1137               LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, false);
1138             }
1139           } else {
1140             LARGE_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1, true);
1141           }
1142         }
1143       });
1144     });
1145   } else {
1146     AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "index_reduce", [&] {
1147       cuda::detail::TensorInfo<scalar_t, uint64_t> selfInfo =
1148         cuda::detail::getTensorInfo<scalar_t, uint64_t>(self_);
1149       int selfReduceDim = selfInfo.collapseDims(dim);
1150       selfInfo.reduceDim(selfReduceDim);
1151       auto alpha_value = (scalar_t) 1;
1152 
1153       cuda::detail::TensorInfo<const scalar_t, uint64_t> sourceInfo =
1154         cuda::detail::getTensorInfo<const scalar_t, uint64_t>(source_);
1155       int sourceReduceDim = sourceInfo.collapseDims(dim);
1156       sourceInfo.reduceDim(sourceReduceDim);
1157 
1158       AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_reduce_cuda", [&] () {
1159         cuda::detail::TensorInfo<const index_t, uint64_t> indexInfo =
1160           cuda::detail::getTensorInfo<const index_t, uint64_t>(index);
1161         indexInfo.collapseDims();
1162 
1163         LARGE_INDEX(scalar_t, index_t, uint64_t, -1, -1, -1, true);
1164       });
1165     });
1166   }
1167 
1168 #undef SMALL_INDEX
1169 #undef LARGE_INDEX
1170 }
1171 
1172 TORCH_IMPL_FUNC(index_add_cuda_out)
1173 (const Tensor& self, int64_t dim, const Tensor& index, const Tensor& source, const Scalar& alpha, const Tensor& result) {
1174   index_add_cuda_impl(self, dim, index, source, alpha, result);
1175 }
1176 
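// Dispatches on the `reduce` string below; e.g. the Python call
//   self.index_reduce_(0, index, source, "mean", include_self=False)
// lands here with reduce == "mean", where the sum is accumulated first and the
// per-slice counts are fixed up afterwards.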
1177 TORCH_IMPL_FUNC(index_reduce_cuda_out)
1178 (const Tensor& self,
1179  int64_t dim,
1180  const Tensor& index,
1181  const Tensor& source,
1182  const c10::string_view reduce,
1183  bool include_self,
1184  const Tensor& result) {
1185   TORCH_WARN_ONCE("index_reduce() is in beta and the API may change at any time.");
1186 
1187   if (reduce == "prod") {
1188     index_reduce_func_cuda_impl(self, dim, index, source, include_self, ReductionType::PROD, reduce_multiply, result);
1189   } else if (reduce == "mean") {
1190     index_reduce_func_cuda_impl(self, dim, index, source, include_self, ReductionType::MEAN, reduce_add, result);
1191     auto counts = include_self ? at::ones_like(result) : at::zeros_like(result);
1192     counts.index_add_(dim, index, at::ones_like(source));
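    // Slices that received no contribution have a count of zero; clamp them to 1
    // so the division below leaves those slices unchanged instead of dividing by zero.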
1193     counts.masked_fill_(counts == 0, 1);
1194     if (result.is_floating_point() || result.is_complex()) {
1195       result.div_(counts);
1196     } else {
1197       result.div_(counts, "floor");
1198     }
1199   } else if (reduce == "amax") {
1200     index_reduce_func_cuda_impl(self, dim, index, source, include_self, ReductionType::MAX, reduce_maximum, result);
1201   } else if (reduce == "amin") {
1202     index_reduce_func_cuda_impl(self, dim, index, source, include_self, ReductionType::MIN, reduce_minimum, result);
1203   } else {
1204     TORCH_CHECK(false, "reduce argument must be either prod, mean, amax or amin, got ", reduce, ".");
1205   }
1206 }
1207 
1208 namespace {
1209 // We prefer this kernel to avoid reloading index points if the number
1210 // of indices is a small number.
1211 // This kernel in fact works for all choices of problem size, but if
1212 // the number of indices chosen is large, then the
1213 // indexSelectLargeIndex kernel is a better choice to increase
1214 // parallelism.
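// For example (illustrative sizes): selecting 4 rows out of a [1000, 512] source
// lets each thread load each of the 4 indices exactly once while the grid as a
// whole strides over the 512 elements of the selected slice.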
1215 template <typename T, typename IndicesType, typename IndexType, int DstDim, int SrcDim, int IdxDim>
1216 __global__ void indexSelectSmallIndex(cuda::detail::TensorInfo<T, IndexType> dst,
1217                                       cuda::detail::TensorInfo<const T, IndexType> src,
1218                                       cuda::detail::TensorInfo<const IndicesType, IndexType> indices,
1219                                       int dstSelectDim,
1220                                       int srcSelectDim,
1221                                       IndexType innerSize,
1222                                       int64_t srcSelectDimSize) {
1223   // In order to avoid reloading the index that we are copying, load
1224   // it once to handle all of the points that are being selected, so
1225   // it can be reused as much as possible. This kernel is chosen when
1226   // this is a good choice (small number of chosen indices), since
1227   // re-accessing indices in addition to src elements can be slow.
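  // Every thread walks the full index list; the grid then strides cooperatively
  // over the innerSize elements belonging to the current index.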
1228   for (IndexType dstIndex = 0; dstIndex < indices.sizes[0]; ++dstIndex) {
1229     IndexType srcIndex =
1230       indices.data[cuda::detail::IndexToOffset<const IndicesType, IndexType, IdxDim>::get(dstIndex, indices)];
1231     CUDA_KERNEL_ASSERT(srcIndex < srcSelectDimSize);
1232 
1233     // We stride over the output ignoring the indexed dimension
1234     // (innerSize), whose offset calculation is handled differently
1235     for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
1236          linearIndex < innerSize;
1237          linearIndex += gridDim.x * blockDim.x) {
1238       IndexType dstOffset =
1239         cuda::detail::IndexToOffset<T, IndexType, DstDim>::get(linearIndex, dst);
1240       dstOffset += dstIndex * dst.strides[dstSelectDim];
1241 
1242       IndexType srcOffset =
1243         cuda::detail::IndexToOffset<const T, IndexType, SrcDim>::get(linearIndex, src);
1244       srcOffset += srcIndex * src.strides[srcSelectDim];
1245 
1246       dst.data[dstOffset] = src.data[srcOffset];
1247     }
1248   }
1249 }
1250 
1251 // We prefer this kernel to balance parallelism across index points,
1252 // if there are a large number of indices.
1253 // This kernel in fact works for all choices of problem size, but if
1254 // the number of indices chosen is small, then the
1255 // indexSelectSmallIndex kernel is a better choice to reduce memory
1256 // accesses.
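// The IndexIsMajor flag controls how the flat output position is decomposed into
// (index, element-within-slice); the caller picks the variant (via
// indexShouldBeMajor) that best matches the destination layout for coalesced accesses.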
1257 template <typename T, typename IndicesType, typename IndexType, int DstDim, int SrcDim, int IdxDim,
1258           bool IndexIsMajor>
1259 __global__ void indexSelectLargeIndex(cuda::detail::TensorInfo<T, IndexType> dst,
1260                                       cuda::detail::TensorInfo<const T, IndexType> src,
1261                                       cuda::detail::TensorInfo<const IndicesType, IndexType> indices,
1262                                       int dstSelectDim,
1263                                       int srcSelectDim,
1264                                       IndexType totalSize,
1265                                       IndexType innerSize,
1266                                       int64_t srcSelectDimSize) {
1267   // We stride over the output including the indexed dimension
1268   // (totalSize), and calculate the destination index point based on that
1269   for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
1270        linearIndex < totalSize;
1271        linearIndex += gridDim.x * blockDim.x) {
1272     IndexType dstIndex, elementInSlice;
1273     if (IndexIsMajor) {
1274       dstIndex = linearIndex / innerSize;
1275       elementInSlice = linearIndex % innerSize;
1276     }
1277     else {
1278       elementInSlice = linearIndex / innerSize;
1279       dstIndex = linearIndex % innerSize;
1280     }
1281 
1282     IndexType srcIndex =
1283       indices.data[cuda::detail::IndexToOffset<const IndicesType, IndexType, IdxDim>::get(dstIndex, indices)];
1284     CUDA_KERNEL_ASSERT(srcIndex < srcSelectDimSize);
1285 
1286     IndexType dstOffset =
1287       cuda::detail::IndexToOffset<T, IndexType, DstDim>::get(elementInSlice, dst);
1288     dstOffset += dstIndex * dst.strides[dstSelectDim];
1289 
1290     IndexType srcOffset =
1291       cuda::detail::IndexToOffset<const T, IndexType, SrcDim>::get(elementInSlice, src);
1292     srcOffset += srcIndex * src.strides[srcSelectDim];
1293 
1294     dst.data[dstOffset] = src.data[srcOffset];
1295   }
1296 }
1297 
1298 namespace {
1299 
1300 // When using a 0-dim scalar tensor, we need the legacy (THC) semantics of
1301 // TensorInfo: Pretend that the scalar tensor is in fact a one-element vector.
1302 template <typename T, typename IndexType>
1303 cuda::detail::TensorInfo<T, IndexType>
1304 tensorInfoLegacyIfScalar(cuda::detail::TensorInfo<T, IndexType> ti) {
1305   if (ti.dims == 0) {
1306     ti.dims = 1;
1307     ti.sizes[0] = 1;
1308     ti.strides[0] = 1;
1309   }
1310   return ti;
1311 }
1312 
1313 }
1314 
1315 template <typename scalar_t>
1316 void index_select_out_cuda_impl(
1317     Tensor& out,
1318     const Tensor& self,
1319     long dim,
1320     const Tensor& index) {
1321   uint64_t numIndices = index.numel();
1322   uint64_t selfDims = self.dim() == 0 ? 1 : self.dim();
1323 
1324   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
1325 
1326   TORCH_CHECK(
1327       index.dim() <= 1, "Index is supposed to be an empty tensor or a vector");
1328   TORCH_CHECK(
1329       !(self.dim() == 0 && numIndices != 1), "index_select(): Index to scalar can have only 1 value, got ", numIndices, " value(s)");
1330   TORCH_CHECK(dim < selfDims, "Indexing dim is out of bounds");
1331 
1332   std::vector<int64_t> newSize = self.sizes().vec();
1333   if (self.dim() > 0) {
1334     newSize[dim] = numIndices;
1335   }
1336 
1337   if (self.is_quantized()) {
1338     out = at::empty_quantized(newSize, out);
1339   } else {
1340     at::native::resize_output(out, newSize);
1341   }
1342 
1343   uint64_t outTotalSize = out.numel();
1344   if (outTotalSize == 0) {
1345     return;
1346   }
1347 
1348   bool indContig = index.is_contiguous();
1349 
1350   // The `self` is partitioned into two parts:
1351   // -the size of each slice we are indexing, which is the
1352   // total size of the tensor ignoring dimension `dim`;
1353   // -the number of indices we are choosing, which is the total size
1354   // of the tensor `indices`.
1355   uint64_t selfSelectDimSize = self.dim() == 0 ? 1 : self.size(dim);
1356   uint64_t sliceSize = outTotalSize / numIndices;
1357 
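  // e.g. for self of shape [8, 16, 32], dim == 1 and 4 indices: out has shape
  // [8, 4, 32], so outTotalSize == 1024 and sliceSize == 1024 / 4 == 256 (= 8 * 32).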
1358   int mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
1359 
1360 #define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM)         \
1361   indexSelectSmallIndex<TENSOR_TYPE, INDICES_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM>     \
1362     <<<smallIndexGrid, smallIndexBlock, 0, stream>>>(                                   \
1363       outInfo, selfInfo, indicesInfo,                                                   \
1364       outSelectDim, selfSelectDim, static_cast<TYPE>(sliceSize),                        \
1365       selfSelectDimSize);                                                               \
1366   C10_CUDA_KERNEL_LAUNCH_CHECK();
1367 
1368 #define LARGE_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE,                           \
1369                     DST_DIM, SRC_DIM, IDX_DIM, IDX_IS_MAJOR)                   \
1370   indexSelectLargeIndex<TENSOR_TYPE, INDICES_TYPE, TYPE,                       \
1371                         DST_DIM, SRC_DIM, IDX_DIM, IDX_IS_MAJOR>               \
1372     <<<largeIndexGrid, largeIndexBlock, 0, stream>>>(                          \
1373       outInfo, selfInfo, indicesInfo,                                          \
1374       outSelectDim, selfSelectDim, static_cast<TYPE>(outTotalSize),            \
1375       static_cast<TYPE>((IDX_IS_MAJOR) ? sliceSize : numIndices),              \
1376       selfSelectDimSize);                                                      \
1377   C10_CUDA_KERNEL_LAUNCH_CHECK();
1378 
1379   dim3 smallIndexGrid(std::min(ceil_div(sliceSize, (uint64_t)128), (uint64_t) (mpc * 8)));
1380   dim3 smallIndexBlock(std::min(sliceSize, (uint64_t)128));
1381 
1382   dim3 largeIndexGrid(std::min(ceil_div(outTotalSize, (uint64_t)128), (uint64_t) (mpc * 8)));
1383   // For https://github.com/pytorch/pytorch/issues/130806 there were two problems:
1384   // 1: ptrdiff_t was used, but it is a signed type; an outTotalSize of 2147483648 overflows it.
1385   // 2: On ROCm, std::min -> ::min did not work as expected when outTotalSize >= 2147483648.
1386   dim3 largeIndexBlock( (outTotalSize < 128) ? outTotalSize : 128 );
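  // Prefer 32-bit offset arithmetic whenever out, self and index are all addressable
  // with unsigned int; the uint64_t path below is the fallback for very large tensors.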
1387   if (cuda::detail::canUse32BitIndexMath(out) &&
1388       cuda::detail::canUse32BitIndexMath(self) &&
1389       cuda::detail::canUse32BitIndexMath(index)) {
1390     auto outInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo<scalar_t, unsigned int>(out));
1391     int outSelectDim = outInfo.collapseDims(dim);
1392     outInfo.reduceDim(outSelectDim);
1393 
1394     auto selfInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo<const scalar_t, unsigned int>(self));
1395     int selfSelectDim = selfInfo.collapseDims(dim);
1396     selfInfo.reduceDim(selfSelectDim);
1397 
1398     AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_select_out_cuda_impl", [&] () {
1399       auto indicesInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo<const index_t, unsigned int>(index));
1400       indicesInfo.collapseDims();
1401 
1402       // Heuristic cutoff: with at most 16 indices it is cheaper to let each thread
1403       // iterate over the whole index list than to parallelize across indices.
1404       if (numIndices <= 16) {
1405         if (outInfo.dims == 1 && selfInfo.dims == 1 && indContig) {
1406           SMALL_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2);
1407         } else if (outInfo.dims == 2 && selfInfo.dims == 2 && indContig) {
1408           SMALL_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2);
1409         } else if (outInfo.dims == 3 && selfInfo.dims == 3 && indContig) {
1410           SMALL_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2);
1411         } else {
1412           SMALL_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1);
1413         }
1414       } else {
1415         bool indexIsMajor = indexShouldBeMajor(outInfo, outSelectDim);
1416 
1417         if (outInfo.dims == 1 && selfInfo.dims == 1 && indContig) {
1418           LARGE_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2, true);
1419         } else if (outInfo.dims == 2 && selfInfo.dims == 2 && indContig) {
1420           if (indexIsMajor) {
1421             LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, true);
1422           } else {
1423             LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, false);
1424           }
1425         } else if (outInfo.dims == 3 && selfInfo.dims == 3 && indContig) {
1426           if (indexIsMajor) {
1427             LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, true);
1428           } else {
1429             LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, false);
1430           }
1431         } else {
1432           LARGE_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1, true);
1433         }
1434       }
1435     });
1436   } else {
1437     auto outInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo<scalar_t, uint64_t>(out));
1438     int outSelectDim = outInfo.collapseDims(dim);
1439     outInfo.reduceDim(outSelectDim);
1440 
1441     auto selfInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo<const scalar_t, uint64_t>(self));
1442     int selfSelectDim = selfInfo.collapseDims(dim);
1443     selfInfo.reduceDim(selfSelectDim);
1444     AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_select_out_cuda_impl", [&] () {
1445       auto indicesInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo<const index_t, uint64_t>(index));
1446       indicesInfo.collapseDims();
1447 
1448       LARGE_INDEX(scalar_t, index_t, uint64_t, -1, -1, -1, true);
1449     });
1450   }
1451 #undef SMALL_INDEX
1452 #undef LARGE_INDEX
1453 }
1454 } // anonymous namespace
1455 
1456 Tensor& index_select_out_cuda(
1457     const Tensor& self,
1458     int64_t dim,
1459     const Tensor& index,
1460     Tensor& out) {
1461   static constexpr string_view DIM_WARNING =
1462       "Tensor too large or too many (> 25) dimensions";
1463   TORCH_CHECK(
1464       at::cuda::check_device({out, self, index}),
1465       "Input, output and indices must be on the current device");
1466   at::assert_no_internal_overlap(out);
1467   at::assert_no_overlap(out, self);
1468   at::assert_no_overlap(out, index);
1469 
1470   dim = at::maybe_wrap_dim(dim, self);
1471   TORCH_CHECK(self.dim() <= MAX_TENSORINFO_DIMS, DIM_WARNING);
1472   TORCH_CHECK(index.dim() <= MAX_TENSORINFO_DIMS, DIM_WARNING);
1473   if (self.is_quantized()) {
1474     TORCH_CHECK(
1475       self.qscheme() == kPerTensorAffine,
1476       "Only per_tensor quantized tensors are supported by index_select.");
1477     AT_DISPATCH_QINT_TYPES(out.scalar_type(), "index_select_quant_cuda", [&] {
1478       index_select_out_cuda_impl<scalar_t>(out, self, dim, index);
1479     });
1480   } else {
1481     AT_DISPATCH_V2(
1482         out.scalar_type(),
1483         "index_select_cuda",
1484         AT_WRAP([&] { index_select_out_cuda_impl<scalar_t>(out, self, dim, index); }),
1485         AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
1486         kComplexHalf,
1487         kHalf,
1488         kBool,
1489         kBFloat16
1490         );
1491   }
1492 
1493   return out;
1494 }
1495 
1496 Tensor index_select_cuda(const Tensor& self, int64_t dim, const Tensor& index) {
1497   Tensor out = at::empty({0}, self.options());
1498   at::native::index_select_out_cuda(self, dim, index, out);
1499   return out;
1500 }
1501 
1502 Tensor index_select_quantized_cuda(const Tensor& self, int64_t dim, const Tensor& index) {
1503   TORCH_CHECK(
1504     self.qscheme() == kPerTensorAffine,
1505     "Only per_tensor quantized tensors are supported by index_select.");
1506   Tensor out = at::empty_quantized({0}, self);
1507   at::native::index_select_out_cuda(self, dim, index, out);
1508   return out;
1509 }
1510 
1511 namespace {
1512 
1513 void masked_fill_kernel(TensorIterator& iter, const Scalar& value) {
1514   AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
1515       kBool, kHalf, kBFloat16, kComplexHalf, iter.common_dtype(), "masked_fill_", [&]() {
1516         const auto value_ = value.to<scalar_t>();
1517         gpu_kernel(
1518             iter, [value_] GPU_LAMBDA(scalar_t self, bool mask) -> scalar_t {
1519               if (mask) {
1520                 return value_;
1521               }
1522               return self;
1523             });
1524       });
1525 }
1526 
1527 template <typename scalar_t>
1528 void cuda_masked_fill_kernel_quantized(TensorIterator& iter, scalar_t quantized_val) {
1529     gpu_kernel(
1530         iter, [quantized_val] GPU_LAMBDA(scalar_t self, bool mask) -> scalar_t {
1531           if (mask) {
1532             return quantized_val;
1533           }
1534           return self;
1535     });
1536 }
1537 
1538 void masked_fill_kernel_quantized(TensorIterator& iter, const Scalar& value, double scale, int zero_point) {
1539   TORCH_CHECK(iter.input_dtype(1) == at::ScalarType::Bool, "masked_fill only supports boolean masks, ",
1540     "but got dtype ", iter.input_dtype(1));
1541   AT_DISPATCH_QINT_TYPES(
1542       iter.common_dtype(), "masked_fill_", [&]() {
1543         float float_val = value.to<float>();
1544         const auto quantized_val = quantize_val<scalar_t>(scale, zero_point, float_val);
1545 
1546         cuda_masked_fill_kernel_quantized<scalar_t>(iter, quantized_val);
1547     });
1548 }
1549 
1550 REGISTER_CUDA_DISPATCH(masked_fill_kernel_quantized_stub, &masked_fill_kernel_quantized);
1551 
1552 } // anonymous namespace
1553 
1554 Tensor & masked_fill__cuda(Tensor& self, const Tensor & mask, const Scalar& value) {
1555   TORCH_CHECK(self.device() == mask.device(), "expected self and mask to be on the same device, but got mask on ",
1556     mask.device(), " and self on ", self.device());
1557   TORCH_CHECK(mask.scalar_type() == kBool,
1558     "masked_fill only supports boolean masks, but got dtype ", mask.scalar_type());
1559   auto maybe_outnames = namedinference::broadcast_to_outnames(self, mask, "masked_fill_");
1560   if (at::has_internal_overlap(self) == MemOverlap::Yes) {
1561     TORCH_WARN(
1562       "Use of masked_fill_ on expanded tensors is deprecated. "
1563       "Please clone() the tensor before performing this operation. "
1564       "This also applies to advanced indexing e.g. tensor[mask] = scalar");
1565   }
1566   at::assert_no_partial_overlap(self, mask);
1567 
1568   c10::MaybeOwned<Tensor> b_mask = expand_inplace(self, mask, "masked_fill_");
1569 
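  // `self` is added both as the output and as a const input so that the elementwise
  // lambda in masked_fill_kernel sees the current value of `self` alongside the
  // broadcast mask.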
1570   auto iter = TensorIteratorConfig()
1571       .set_check_mem_overlap(false)
1572       .check_all_same_dtype(false)
1573       .resize_outputs(false)
1574       .add_output(self)
1575       .add_const_input(self)
1576       .add_const_input(*b_mask)
1577       .build();
1578 
1579   masked_fill_kernel(iter, value);
1580   namedinference::propagate_names_if_nonempty(self, maybe_outnames);
1581   return self;
1582 }
1583 
1584 Tensor & masked_fill__cuda(Tensor& self, const Tensor & mask, const Tensor & value) {
1585   TORCH_CHECK(value.dim() == 0, "masked_fill_ only supports a 0-dimensional value tensor, but got tensor "
1586       "with ", value.dim(), " dimension(s).");
1587   // We hit this function if either of the input tensors lives on CUDA.
1588   // It is OK if `value` is a CPU tensor, but we should not allow `self` or
1589   // `mask` to be CPU tensors. The check that `self` and `mask` are on the same
1590   // device lives in `masked_fill__cuda` (Scalar version).
1591   TORCH_CHECK(!self.device().is_cpu(), "masked_fill_: Expected inputs to be on the same device");
1592   return masked_fill__cuda(self, mask, value.item());
1593 }
1594 
1595 namespace {
1596 
1597 // ForwardIt: only legacy random access iterator is supported.
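// find_bound is a device-friendly binary search over a sorted range: with
// is_lower == true it returns the first position whose value is not less than
// `value` (lower bound), otherwise the first position whose value exceeds it
// (upper bound). e.g. for {1, 1, 3, 5} and value == 1 the bounds are offsets 0
// and 2, so their difference counts the occurrences of 1.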
1598 template<class ForwardIt, class T, bool is_lower = true>
1599 static __host__ __device__ __forceinline__
1600 ForwardIt find_bound(ForwardIt first, ForwardIt last, const T& value) {
1601     ForwardIt it;
1602     typename std::iterator_traits<ForwardIt>::difference_type count, step;
1603     // NOTE: std::distance(first, last) compiles but produces wrong results here,
1604     // so only legacy random access iterators are safe in this code.
1605     count = last - first;
1606 
1607     while (count > 0) {
1608       it = first;
1609       step = count / 2;
1610       // avoiding std::advance(it, step),
1611       // although it does work unlike std::distance
1612       it += step;
1613       if (is_lower ? *it < value : value >= *it) {
1614         first = ++it;
1615         count -= step + 1;
1616       }
1617       else {
1618         count = step;
1619       }
1620     }
1621     return first;
1622 }
1623 
1624 }
1625 
1626 Tensor index_select_sparse_cuda(const Tensor& self, int64_t dim, const Tensor& index) {
1627   const auto ndim = self.dim();
1628   TORCH_CHECK_INDEX(ndim, "index_select() cannot be applied to a 0-dim tensor.");
1629   TORCH_CHECK_INDEX(
1630       index.dim() == 1 && index.dtype() == at::kLong && index.options().layout() == at::kStrided,
1631       "index_select() argument index must be 1-D strided (non-sparse) long-tensor.");
1632   dim = maybe_wrap_dim(dim, ndim);
1633   const auto size = self.size(dim);
1634   const auto sparse_dim = self.sparse_dim();
1635   const auto dense_dim = self.dense_dim();
1636   const auto indices = self._indices();
1637   const auto values = self._values();
1638   const auto nnz = values.size(0);
1639   const auto index_len = index.size(0);
1640   auto res_sizes = self.sizes().vec();
1641   res_sizes[dim] = index_len;
1642 
1643   // If indexing into sparse dimensions
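  // The branch below intersects `index` with the dim-th row of `indices`: that row
  // is sorted once, each query index (made non-negative) is binary-searched to obtain
  // its match count and first match, per-query output offsets come from an exclusive
  // cumulative sum of the counts, and the matching nnz positions together with their
  // query ids are then written out to assemble the result.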
1644   if (dim < sparse_dim) {
1645     const auto make_output = [
1646       dim, sparse_dim, dense_dim, res_sizes, &self, &indices, &values
1647     ](
1648         const Tensor& selected_dim_indices,
1649         const Tensor& res_dim_indices
1650     ) -> Tensor {
1651       auto res_indices = indices.index_select(1, selected_dim_indices);
1652       res_indices[dim] = res_dim_indices;
1653       const auto res_values = values.index_select(0, selected_dim_indices);
1654 
1655       return at::_sparse_coo_tensor_with_dims_and_tensors(
1656           sparse_dim, dense_dim, res_sizes, res_indices, res_values, self.options());
1657     };
1658 
1659     // short-circuit if index is empty
1660     if (!index_len) {
1661       return make_output(index, index);
1662     }
1663 
1664     const auto nneg_index = [&index, size]() -> Tensor {
1665       auto nneg_index = at::empty_like(index, at::MemoryFormat::Contiguous);
1666 
1667       auto iter = TensorIteratorConfig()
1668         .add_output(nneg_index)
1669         .add_input(index)
1670         .build();
1671 
1672       AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_select_sparse_cuda", [&]() {
1673           gpu_kernel(iter, [size] GPU_LAMBDA (index_t idx) -> index_t {
1674               CUDA_KERNEL_ASSERT(idx >= -size && idx < size
1675                   && "index_select(): index out of bounds");
1676               return idx < 0 ? idx + size : idx;
1677           });
1678       });
1679       return nneg_index;
1680     }();
1681 
1682     const auto dim_indices = indices[dim].contiguous();
1683     const auto idx_nneg_index = at::arange(index_len, nneg_index.options());
1684     const auto idx_dim_indices = at::arange(nnz, dim_indices.options());
1685 
1686     Tensor sorted_dim_indices, argsort_dim_indices;
1687     std::tie(sorted_dim_indices, argsort_dim_indices) = [&]() -> std::tuple<Tensor, Tensor> {
1688       if (dim == 0 && self.is_coalesced()) {
1689         return std::make_tuple(dim_indices, idx_dim_indices);
1690       }
1691       else {
1692         return dim_indices.sort();
1693       }
1694     }();
1695 
1696     Tensor intrsc_counts_nneg_index;
1697     Tensor intrsc_first_match_nneg_index;
1698     std::tie(intrsc_counts_nneg_index, intrsc_first_match_nneg_index) = [&]() -> std::tuple<Tensor, Tensor> {
1699       auto intrsc_counts_nneg_index = at::zeros_like(nneg_index);
1700       auto intrsc_first_match_nneg_index = at::zeros_like(nneg_index);
1701 
1702       auto iter = TensorIteratorConfig()
1703         .add_output(intrsc_first_match_nneg_index)
1704         .add_input(nneg_index)
1705         .add_input(idx_nneg_index)
1706         .build();
1707 
1708       AT_DISPATCH_INDEX_TYPES(nneg_index.scalar_type(), "index_select_sparse_cuda", [&]() {
1709           index_t* ptr_intrsc_counts_nneg_index = intrsc_counts_nneg_index.mutable_data_ptr<index_t>();
1710           const index_t* ptr_sorted_dim_indices = sorted_dim_indices.const_data_ptr<index_t>();
1711           gpu_kernel(
1712               iter,
1713               [ptr_intrsc_counts_nneg_index, ptr_sorted_dim_indices, nnz] GPU_LAMBDA (
1714                 index_t idx_val, index_t idx_idx
1715               ) -> index_t {
1716                 auto* lb = find_bound<const index_t*, index_t, true>(
1717                   ptr_sorted_dim_indices,
1718                   ptr_sorted_dim_indices + nnz,
1719                   idx_val
1720                 );
1721                 auto* ub = find_bound<const index_t*, index_t, false>(
1722                   ptr_sorted_dim_indices,
1723                   ptr_sorted_dim_indices + nnz,
1724                   idx_val
1725                 );
1726                 const auto idx_count = ub - lb;
1727                 ptr_intrsc_counts_nneg_index[idx_idx] = idx_count;
1728 
1729                 return lb - ptr_sorted_dim_indices;
1730               }
1731           );
1732       });
1733 
1734       return std::make_tuple(intrsc_counts_nneg_index, intrsc_first_match_nneg_index);
1735     }();
1736 
1737     // Unavoidable sync since the shape of the result is not known in advance
1738     auto res_len = intrsc_counts_nneg_index.sum().item<int64_t>();
1739     // Short-circuit if empty intersection
1740     if (!res_len) {
1741       auto empty_idx = at::empty({0}, nneg_index.options());
1742       return make_output(empty_idx, empty_idx);
1743     }
1744 
1745     Tensor selected_dim_indices, res_dim_indices;
1746     std::tie(selected_dim_indices, res_dim_indices) = [&]() -> std::tuple<Tensor, Tensor> {
1747       auto res_dim_indices = at::empty({res_len}, nneg_index.options());
1748       auto selected_dim_indices = at::empty_like(res_dim_indices);
1749       auto selected_dim_indices_offsets = intrsc_counts_nneg_index.cumsum(0)
1750         .sub_(intrsc_counts_nneg_index);
1751 
1752       // Need to have output as TensorIterator does not allow having void lambdas.
1753       auto dummy_output = at::empty({1}, dim_indices.options()).expand(IntArrayRef({index_len}));
1754       auto iter = TensorIteratorConfig()
1755         .add_output(dummy_output)
1756         // All iterations map to a single element in dummy_output by design,
1757         // hence removed output memory overlap check.
1758         .set_check_mem_overlap(false)
1759         .add_input(idx_nneg_index)
1760         .add_input(intrsc_counts_nneg_index)
1761         .add_input(selected_dim_indices_offsets)
1762         .add_input(intrsc_first_match_nneg_index)
1763         .build();
1764 
1765       AT_DISPATCH_INDEX_TYPES(nneg_index.scalar_type(), "index_select_sparse_cuda", [&]() {
1766           index_t* ptr_res_dim_indices = res_dim_indices.mutable_data_ptr<index_t>();
1767           index_t* ptr_selected_dim_indices = selected_dim_indices.mutable_data_ptr<index_t>();
1768           const index_t* ptr_argsort_dim_indices = argsort_dim_indices.const_data_ptr<index_t>();
1769           gpu_kernel(
1770               iter,
1771               [ptr_res_dim_indices, ptr_selected_dim_indices, ptr_argsort_dim_indices] GPU_LAMBDA (
1772                 index_t idx_idx, index_t count, index_t offset, index_t first_match
1773               ) -> index_t {
1774                 index_t* __restrict__ ptr_res_dim_indices_out = ptr_res_dim_indices + offset;
1775                 const index_t* __restrict__ ptr_argsort_dim_indices_in = ptr_argsort_dim_indices + first_match;
1776                 index_t* __restrict__ ptr_selected_dim_indices_out = ptr_selected_dim_indices + offset;
1777                 for (index_t i = 0; i < count; ++i) {
1778                   *ptr_res_dim_indices_out++ = idx_idx;
1779                   *ptr_selected_dim_indices_out++ = *ptr_argsort_dim_indices_in++;
1780                 }
1781 
1782                 // A dummy return scalar for a dummy output
1783                 return static_cast<index_t>(1);
1784               }
1785           );
1786       });
1787 
1788       return std::make_tuple(selected_dim_indices, res_dim_indices);
1789     }();
1790 
1791     return make_output(selected_dim_indices, res_dim_indices);
1792   }
1793   // If indexing into dense dimensions
1794   else {
1795     // It is sufficient to just perform `index_select` on values
1796     // if `dim` refers to dense dimensions.
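    // `values` carries a leading nnz dimension, so dense dimension `dim` of `self`
    // maps to dimension `dim - sparse_dim + 1` of `values`.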
1797     const auto res_values = values.index_select(dim - sparse_dim + 1, index);
1798 
1799     return _sparse_coo_tensor_with_dims_and_tensors(
1800         sparse_dim, dense_dim, res_sizes, indices, res_values, self.options());
1801   }
1802 }
1803 
1804 
1805 } // at::native
1806