1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/TensorAdvancedIndexing.h>
3 #include <ATen/native/IndexingUtils.h>
4 #include <ATen/native/quantized/IndexKernel.h>
5 #include <ATen/native/cuda/KernelUtils.cuh>
6
7 #include <ATen/core/Tensor.h>
8 #include <ATen/ceil_div.h>
9 #include <ATen/Dispatch.h>
10 #include <ATen/Dispatch_v2.h>
11 #include <ATen/ExpandUtils.h>
12 #include <ATen/MemoryOverlap.h>
13 #include <ATen/TensorOperators.h>
14 #include <ATen/native/TensorIterator.h>
15 #include <ATen/native/cuda/Loops.cuh>
16 #include <ATen/native/Resize.h>
17 #include <ATen/cuda/detail/IndexUtils.cuh>
18 #include <ATen/cuda/CUDAUtils.h>
19 #include <ATen/cuda/DeviceUtils.cuh>
20
21 #ifndef AT_PER_OPERATOR_HEADERS
22 #include <ATen/Functions.h>
23 #include <ATen/NativeFunctions.h>
24 #else
25 #include <ATen/ops/_assert_async.h>
26 #include <ATen/ops/arange.h>
27 #include <ATen/ops/empty.h>
28 #include <ATen/ops/zeros_like.h>
29 #include <ATen/ops/ones_like.h>
30 #include <ATen/ops/empty_quantized.h>
31 #include <ATen/ops/index_add_native.h>
32 #include <ATen/ops/index_reduce_native.h>
33 #include <ATen/ops/index_select_native.h>
34 #include <ATen/ops/masked_fill_native.h>
35 #include <ATen/ops/_sparse_coo_tensor_with_dims_and_tensors.h>
36 #endif
37
38 #include <ATen/cuda/CUDAContext.h>
39 #include <ATen/cuda/cub.h>
40 #include <c10/util/irange.h>
41 #include <c10/core/QScheme.h>
42 #include <ATen/native/quantized/AffineQuantizerBase.h>
43
44 #include <limits>
45
46 #include <c10/macros/Macros.h>
47
48 namespace {
49 template <typename scalar_t, int SZ>
50 __global__ void indexing_backward_kernel(
51 const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
52 int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
53   // numel is the total number of flattened indices, not expanded to dimensions that are not indexed.
54   // stride is the cumulative size of the non-indexed last dimensions
55   // stride_before is the stride of the dimension immediately preceding the first indexed dimension
56   // if indexing starts from the 0th dimension, stride_before does not matter because blockIdx.z will be 0 in this case
57   // outer_dim is the number of elements in the first non-indexed dimensions
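  // Illustrative example (added, not from the original source): for x of shape (B, N, D)
  // indexed as x[:, idx, :] with idx of length L, roughly:
  //   numel = L, stride = D (size of the trailing non-indexed dims),
  //   stride_before = x.stride(0) = N * D, outer_dim = B.
  // Each blockIdx.z then handles one of the B outer slices.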
58 using opmath_t = at::opmath_type<scalar_t>;
59
60   // Each warp is responsible for one input index (one destination row of grad_weight).
61 // If the preceding input has the same destination index as this input, then the warp
62 // exits immediately. The warp also processes subsequent inputs with the
63 // same value.
64 //
65 // Input Warp
66 // 1 <warp 1>
67 // 1 <warp 1> (<warp 2> exits without doing any work)
68 // 5 <warp 3>
69 // 8 <warp 4>
70
71 // Number of values processed by each thread (grain size)
72 for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z){
73 int64_t idx = blockIdx.x * blockDim.y + threadIdx.y;
74 if (idx < numel
75 && (idx == 0 || sorted_indices[idx] != sorted_indices[idx - 1])){
76 do {
77 int64_t start_feature = threadIdx.x + blockIdx.y * blockDim.x * SZ;
78 // if not accumulate, we only keep the last duplicate index so skip those before it
79 if (!accumulate && (idx < numel - 1) && sorted_indices[idx] == sorted_indices[idx + 1]) {
80 idx++;
81 continue;
82 }
83 const int64_t weight_row = ((int64_t) sorted_indices[idx]) * stride + z * stride_before;
84 const int64_t grad_row = ((int64_t) indices[idx]) * stride + z * numel * stride;
85 const opmath_t scale = (opmath_t)1.0;
86
87 opmath_t gradient[SZ];
88 opmath_t weight[SZ];
89
90 while (start_feature < stride) {
91 #pragma unroll
92 for (int ii = 0; ii < SZ; ii++) {
93 int64_t feature_dim = start_feature + ii * C10_WARP_SIZE;
94 if (feature_dim < stride) {
95 gradient[ii] = static_cast<opmath_t>(grad_output[grad_row + feature_dim]);
96 if (accumulate) {
97 weight[ii] = static_cast<opmath_t>(grad_weight[weight_row + feature_dim]);
98 }
99 }
100 }
101
102 #pragma unroll
103 for (int ii = 0; ii < SZ; ii++) {
104 if (accumulate) {
105 weight[ii] += gradient[ii] * scale;
106 } else {
107 weight[ii] = gradient[ii] * scale;
108 }
109 }
110
111 #pragma unroll
112 for (int ii = 0; ii < SZ; ii++) {
113 int64_t feature_dim = start_feature + ii * C10_WARP_SIZE;
114 if (feature_dim < stride) {
115 grad_weight[weight_row + feature_dim] = static_cast<scalar_t>(weight[ii]);
116 }
117 }
118 start_feature += gridDim.y * blockDim.x * SZ;
119 }
120
121 idx++;
122 } while (idx < numel && sorted_indices[idx] == sorted_indices[idx - 1]);
123 }
124 }
125 }
126
127 template <typename scalar_t>
128 __global__ void indexing_backward_kernel_stride_1(
129 const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
130 int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
131 using opmath_t = at::opmath_type<scalar_t>;
132
133 // Number of values processed by each thread (grain size)
134 for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z){
135 int64_t idx = blockIdx.x * blockDim.y + threadIdx.y;
136 int64_t crnt_sorted_idx = sorted_indices[idx];
137
138 if ((idx < numel) &&
139 (idx == 0 || crnt_sorted_idx != sorted_indices[idx - 1]))
140 {
141 // Determine the number of duplicates in advance
142 int64_t num_duplicates = 1;
143 while (((idx + num_duplicates) < numel) && (sorted_indices[idx + num_duplicates] == crnt_sorted_idx)) {
144 num_duplicates++;
145 }
146
147 // Continue computing weights
148 const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before;
149 int64_t grad_row = 0;
150 const opmath_t scale = (opmath_t)1.0;
151
152 if (!accumulate) {
153 grad_row = ((int64_t)indices[idx + num_duplicates - 1]) * stride + z * numel * stride;
154 grad_weight[weight_row] =
155 static_cast<scalar_t>(static_cast<opmath_t>(grad_output[grad_row]) * scale);
156 } else {
157 opmath_t gradient = (opmath_t)0.0;
158
159 int laneIdx = threadIdx.x % C10_WARP_SIZE;
160 int64_t num_warp_passes = num_duplicates / C10_WARP_SIZE;
161 for (int64_t i = 0; i < num_warp_passes; ++i) {
162 grad_row = ((int64_t) indices[idx + i * C10_WARP_SIZE + laneIdx]) * stride + z * numel * stride;
163 gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
164 }
165 WARP_SYNC();
166 for (int offset = C10_WARP_SIZE / 2; offset > 0; offset /= 2) {
167 gradient += WARP_SHFL_DOWN(gradient, offset);
168 }
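        // Added note: WARP_SHFL_DOWN above performs a standard tree reduction within
        // the warp, e.g. with a warp size of 32 the offsets 16, 8, 4, 2, 1 fold the
        // partial sums so that lane 0 ends up holding the sum over all lanes; only
        // lane 0 then writes the result below.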
169
170 if (laneIdx == 0) {
171 for (int64_t i = num_warp_passes * C10_WARP_SIZE; i < num_duplicates; ++i) {
172 grad_row = ((int64_t) indices[idx + i]) * stride + z * numel * stride;
173 gradient += static_cast<opmath_t>(grad_output[grad_row]) * scale;
174 }
175
176 grad_weight[weight_row] = static_cast<scalar_t>(static_cast<opmath_t>(grad_weight[weight_row]) + gradient);
177 }
178 }
179 }
180 }
181 }
182
183 template <typename scalar_t>
184 __global__ void indexing_backward_kernel_small_stride(
185 const int64_t* sorted_indices, const int64_t* indices, const scalar_t* grad_output, scalar_t* grad_weight,
186 int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim, bool accumulate) {
187 using opmath_t = at::opmath_type<scalar_t>;
188
189 // Number of values processed by each thread (grain size)
190 for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z){
191 int64_t idx = blockIdx.x * blockDim.y + threadIdx.y;
192 int64_t tidx = threadIdx.x;
193 int64_t crnt_sorted_idx = sorted_indices[idx];
194
195 if ((idx < numel) &&
196 (tidx < stride) &&
197 (idx == 0 || crnt_sorted_idx != sorted_indices[idx - 1]))
198 {
199 // Determine the number of duplicates in advance
200 int64_t num_duplicates = 1;
201 while (((idx + num_duplicates) < numel) && (sorted_indices[idx + num_duplicates] == crnt_sorted_idx)) {
202 num_duplicates++;
203 }
204
205 // Continue computing weights
206 const int64_t weight_row = crnt_sorted_idx * stride + z * stride_before;
207 int64_t grad_row = 0;
208 const opmath_t scale = (opmath_t)1.0;
209
210 if (!accumulate) {
211 grad_row = ((int64_t)indices[idx + num_duplicates - 1]) * stride + z * numel * stride;
212 grad_weight[weight_row + tidx] =
213 static_cast<scalar_t>(static_cast<opmath_t>(grad_output[grad_row + tidx]) * scale);
214 } else {
215 opmath_t gradient = (opmath_t)0.0;
216 for (int64_t i = 0; i < num_duplicates; ++i) {
217 grad_row = ((int64_t) indices[idx + i]) * stride + z * numel * stride;
218 gradient += static_cast<opmath_t>(grad_output[grad_row + tidx]) * scale;
219 }
220
221 grad_weight[weight_row + tidx] = static_cast<scalar_t>(static_cast<opmath_t>(grad_weight[weight_row + tidx]) + gradient);
222 }
223 }
224 }
225 }
226
227 template <typename scalar_t, int SZ>
228 __global__ void indexing_backward_kernel_quantized(
229 const int64_t* sorted_indices, const int64_t* indices, const float* grad_output, scalar_t* grad_weight,
230 int64_t numel, int64_t stride, int64_t stride_before, int64_t outer_dim,
231 float inv_scale, int zero_point, int64_t qmin, int64_t qmax) {
232
233   // This implementation is adapted from indexing_backward_kernel above.
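  // Added note: the result is re-quantized with the usual per-tensor affine scheme,
  // roughly q = clamp(zero_point + round(value * inv_scale), qmin, qmax),
  // where inv_scale == 1 / scale.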
234 using opmath_t = at::opmath_type<float>;
235 for (int64_t z = blockIdx.z; z < outer_dim; z += gridDim.z){
236 int64_t idx = blockIdx.x * blockDim.y + threadIdx.y;
237 if (idx < numel
238 && (idx == 0 || sorted_indices[idx] != sorted_indices[idx - 1])){
239 do {
240 int64_t start_feature = threadIdx.x + blockIdx.y * blockDim.x * SZ;
241 // we only keep the last duplicate index so skip those before it
242 if ((idx < numel - 1) && sorted_indices[idx] == sorted_indices[idx + 1]) {
243 idx++;
244 continue;
245 }
246 const int64_t weight_row = ((int64_t) sorted_indices[idx]) * stride + z * stride_before;
247 const int64_t grad_row = ((int64_t) indices[idx]) * stride + z * numel * stride;
248 const opmath_t scale = (opmath_t)1.0;
249
250 opmath_t gradient[SZ];
251 opmath_t weight[SZ];
252
253 while (start_feature < stride) {
254 #pragma unroll
255 for (int ii = 0; ii < SZ; ii++) {
256 int64_t feature_dim = start_feature + ii * C10_WARP_SIZE;
257 if (feature_dim < stride) {
258 gradient[ii] = static_cast<opmath_t>(grad_output[grad_row + feature_dim]);
259 }
260 }
261
262 #pragma unroll
263 for (int ii = 0; ii < SZ; ii++) {
264 weight[ii] = gradient[ii] * scale;
265 }
266
267 #pragma unroll
268 for (int ii = 0; ii < SZ; ii++) {
269 int64_t feature_dim = start_feature + ii * C10_WARP_SIZE;
270 if (feature_dim < stride) {
271 // we do quantization here
272 int64_t qvalue = static_cast<int64_t>(zero_point + nearbyintf(weight[ii]* inv_scale));
273 qvalue = min(max(qvalue, qmin), qmax);
274 grad_weight[weight_row + feature_dim] = static_cast<scalar_t>(qvalue);
275 }
276 }
277 start_feature += gridDim.y * blockDim.x * SZ;
278 }
279
280 idx++;
281 } while (idx < numel && sorted_indices[idx] == sorted_indices[idx - 1]);
282 }
283 }
284 }
285
286
287 }
288
289
290 namespace at::native {
291
292 namespace {
293
294 class ReduceMultiply {
295 public:
296 template <typename scalar_t>
297   constexpr C10_DEVICE void operator() (scalar_t* self_data_start, int64_t index, int64_t numel, const scalar_t * src_data) const {
298 (void)numel; // suppress unused warning
299 gpuAtomicMul(self_data_start + index, *src_data);
300 }
301 };
302 static ReduceMultiply reduce_multiply;
303
304 class ReduceAdd {
305 public:
306 template <typename scalar_t>
307   constexpr C10_DEVICE void operator() (scalar_t* self_data_start, int64_t index, int64_t numel, const scalar_t * src_data) const {
308 fastAtomicAdd(self_data_start, index, numel, *src_data, true);
309 }
310 };
311 static ReduceAdd reduce_add;
312
313 class ReduceMinimum {
314 public:
315 template <typename scalar_t>
316   constexpr C10_DEVICE void operator() (scalar_t* self_data_start, int64_t index, int64_t numel, const scalar_t * src_data) const {
317 (void)numel; // suppress unused warning
318 gpuAtomicMin(self_data_start + index, *src_data);
319 }
320 };
321 static ReduceMinimum reduce_minimum;
322
323 class ReduceMaximum {
324 public:
325 template <typename scalar_t>
326   constexpr C10_DEVICE void operator() (scalar_t* self_data_start, int64_t index, int64_t numel, const scalar_t * src_data) const {
327 (void)numel; // suppress unused warning
328 gpuAtomicMax(self_data_start + index, *src_data);
329 }
330 };
331 static ReduceMaximum reduce_maximum;
332
333 }
334
335 static Tensor wrapIndexOnce(const Tensor & index, int64_t dim, int64_t dim_size, bool check_range=true) {
336   // we don't need to check the range in backward: if there were out-of-bounds indices, forward should already have errored out
337 if (index.numel() != 0 && check_range) {
338 at::_assert_async(index.max() < dim_size);
339 at::_assert_async(index.min() >= -dim_size);
340 }
341 return index.remainder(dim_size);
342 }
343
344 static std::vector<int64_t> computeLinearStride(const Tensor & tensor) {
345 // computes the stride as if tensor were contiguous
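  // Worked example (added for clarity): for sizes (2, 3, 4) the reverse partial_sum
  // below yields strides (12, 4, 1), i.e. the strides the tensor would have if it
  // were contiguous.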
346 auto sizes = tensor.sizes();
347 std::vector<int64_t> stride(tensor.dim());
348 if (stride.empty()) {
349 return stride;
350 }
351 stride[tensor.dim() - 1] = 1;
352 std::partial_sum(sizes.rbegin(), sizes.rend() - 1, stride.rbegin() + 1, std::multiplies<int64_t>());
353 return stride;
354 }
355
356 static std::tuple<Tensor, int64_t, int64_t, int64_t>
357 computeLinearIndex(const Tensor & src, TensorList indices, bool check_range) {
358 auto strides = computeLinearStride(src);
359 const auto& device = src.options().device();
360
361 // Compute the linear index by multiplying the indexing tensors by the
362 // stride and summing them. All the indexing tensors have the same shape at
363 // this point. We also compute the number of dimensions before and after that
364   // are not being indexed.
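  // Illustrative example (added, not part of the original comment): for src of shape
  // (A, B, C) with only dimension 1 indexed, linearIndex = idx * strides[1],
  // nElemBefore = A, nElemAfter = C, and strideBefore = src.stride(0).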
365 Tensor linearIndex;
366 int64_t nElemBefore = 1, nElemAfter = 1, strideBefore =0;
367 for (const auto i: c10::irange(src.dim())) {
368 if (indices[i].defined()) {
369 // Cast index to the longType matching src's device
370       // This allows us to support e.g. indexing a cuda tensor with a cpu tensor
371 Tensor index = (wrapIndexOnce(indices[i], i, src.size(i), check_range) * strides[i]).to(device);
372 if (linearIndex.defined()) {
373 linearIndex += index;
374 } else {
375 linearIndex = index;
376 if (i>0) {
377 strideBefore = src.stride(i-1); // stride after undefined dimensions
378 }
379 }
380 } else if (linearIndex.defined()) {
381 nElemAfter *= src.size(i);
382 } else {
383 nElemBefore *= src.size(i);
384 }
385 }
386
387 return std::make_tuple(std::move(linearIndex), nElemBefore, strideBefore, nElemAfter);
388 }
389
390
391 static std::tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t>> makeLinearIndex(Tensor self, IOptTensorListRef orig, bool check_range) {
392 checkIndexTensorTypes(orig, /*allow_int*/true);
393 // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
394 auto indices = expandTensors(self, orig);
395 for (auto & i : indices) {
396 if (i.defined() && i.dtype() == at::kInt) {
397 i = i.to(at::kLong);
398 }
399 }
400 // next broadcast all index tensors together
401 indices = expand_outplace(indices);
402 // add missing null Tensors so that it matches self.dim()
403 while (indices.size() < (size_t)self.dim()) {
404 indices.emplace_back();
405 }
406 // if the non-null indices are not all adjacent, transpose self and indices
407 // together so that they're adjacent at the front
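  // E.g. (added example): for x[idx0, :, idx2] the indexed dims 0 and 2 are not
  // adjacent, so self and indices are permuted to the order (0, 2, 1) first;
  // inversePerm records how to undo that permutation afterwards.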
408 std::vector<int64_t> inversePerm;
409 if (!hasContiguousSubspace(indices)) {
410 std::tie(self, indices, inversePerm) = transposeToFrontAndInvPerm(self, indices);
411 }
412 auto [linearIndex, nElemBefore, strideBefore, nElemAfter] = computeLinearIndex(self, indices, check_range);
413 return std::make_tuple(linearIndex, self, nElemBefore, strideBefore, nElemAfter, inversePerm);
414 }
415
416
417 void index_put_with_sort_kernel_thrust_helper(Tensor &linearIndex, Tensor &orig_indices, Tensor &sorted_indices, int64_t num_indices);
418
419 namespace {
420
421 int64_t largestIndex(const Tensor &self) {
422 int64_t result = 0;
423 for (const auto i: c10::irange(self.dim())) {
424 result += (self.sizes()[i] - 1) * self.strides()[i];
425 }
426 return result;
427 }
428
429 void index_put_with_sort_kernel(Tensor & self, const c10::List<std::optional<Tensor>>& indices, const Tensor & value, bool accumulate, bool unsafe) {
430 TORCH_CHECK(!indices.empty() || is_expandable_to(value.sizes(), self.sizes()), "shape mismatch: value tensor of shape ", value.sizes(),
431 " cannot be broadcast to indexing result of shape ", self.sizes());
432 if (indices.size() > (size_t)self.dim()) {
433 TORCH_CHECK_INDEX(false, "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
434 }
435 bool self_contiguous = self.is_contiguous();
436 auto self_ = self_contiguous ? self : self.contiguous();
437 Tensor linearIndex, src, expandedValue = value;
438 int64_t nElemBefore, strideBefore, sliceSize;
439 std::vector<int64_t> inversePerm;
440 std::tie(linearIndex, src, nElemBefore, strideBefore, sliceSize, inversePerm) = makeLinearIndex(self_, indices, !unsafe);
441 int64_t num_indices = linearIndex.numel();
442
443 if (expandedValue.numel() < num_indices * nElemBefore * sliceSize) {
444 auto expanded_size = at::DimVector(expandedValue.sizes());
445 auto size1 = expandedValue.sizes();
446 auto size2 = linearIndex.sizes();
447 if (are_expandable(size1, size2)) {
448 expanded_size = infer_size_dimvector(size1, size2);
449 }
450 if (nElemBefore > 1) {
451 expanded_size.insert(expanded_size.begin(), nElemBefore);
452 }
453 if (sliceSize > 1) {
454 expanded_size.insert(expanded_size.end(), sliceSize);
455 }
456 expandedValue = expandedValue.expand(expanded_size);
457 }
458 expandedValue = expandedValue.contiguous();
459
460 if (num_indices > 0 && sliceSize > 0) {
461 const bool permuted = !src.is_contiguous();
462 auto src_ = permuted ? src.contiguous() : src;
463 linearIndex = linearIndex.reshape(-1);
464 auto sorted_indices = at::empty_like(linearIndex, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
465 auto orig_indices = at::empty_like(linearIndex, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
466 const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
467
468 linearIndex.divide_(sliceSize, "trunc");
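    // Added note: at this point linearIndex holds element offsets into the
    // (contiguous) source; dividing by sliceSize (the number of trailing,
    // non-indexed elements per slice) turns them into slice indices, which are
    // what gets sorted below.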
469
470       // cub on CUDA <= 11.2 has a bug where, for small sizes,
471       // cub's sort can be much slower than thrust's merge sort.
472       // This bug is fixed in CUDA 11.3.
473 #if (defined(CUDA_VERSION) && CUDA_VERSION < 11030) || defined(USE_ROCM)
474 if (num_indices < 50000) {
475 index_put_with_sort_kernel_thrust_helper(linearIndex, orig_indices, sorted_indices, num_indices);
476 } else
477 #endif
478 {
479 // Sort the inputs into sorted with the corresponding indices
480 auto range = at::arange(num_indices, linearIndex.options());
481       // linearIndex cannot be negative, and we take advantage of this
482       // fact to sort on fewer bits for better performance.
483 int64_t nbits = cuda::cub::get_num_bits(largestIndex(self_) / sliceSize);
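      // Added note: get_num_bits returns the number of bits needed for the largest
      // possible key, e.g. a maximum slice index of 1000 needs only 10 bits, so the
      // radix sort can skip the upper 54 bits of the int64 keys.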
484 cuda::cub::radix_sort_pairs(
485 linearIndex.const_data_ptr<int64_t>(), sorted_indices.mutable_data_ptr<int64_t>(),
486 range.const_data_ptr<int64_t>(), orig_indices.mutable_data_ptr<int64_t>(),
487 num_indices, false, 0, nbits);
488 }
489
490 TORCH_INTERNAL_ASSERT(
491 linearIndex.numel()*sliceSize*nElemBefore == expandedValue.numel(),
492 "number of flattened indices did not match number of elements in the value tensor: ",
493 linearIndex.numel()*sliceSize*nElemBefore, " vs ", expandedValue.numel());
494 const int UNROLL = 4;
495 const int indices_per_block = 4;
496 const int warp_size = at::cuda::warp_size();
497 dim3 grid(ceil_div(num_indices, (int64_t) indices_per_block),
498 std::min<int>(at::cuda::getCurrentDeviceProperties()->maxGridSize[1], ceil_div(sliceSize, (int64_t) (warp_size*UNROLL))),
499 std::min(std::max<int>(1,nElemBefore), at::cuda::getCurrentDeviceProperties()->maxGridSize[2]));
500 dim3 block(warp_size, indices_per_block);
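    // Added commentary on the launch geometry above: blockDim.x (one warp) covers the
    // feature/slice dimension and blockDim.y handles indices_per_block indices per
    // block; grid.x tiles the sorted indices, grid.y tiles the slice in chunks of
    // warp_size * UNROLL, and grid.z strides over the outer (non-indexed) dimension.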
501
502
503 if (sliceSize == 1) {
504 // This implementation is faster with high amounts of duplicates but could overflow
505 // if FP16 / BF16 is used
506 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBool, kBFloat16,
507 expandedValue.scalar_type(), "indexing_backward_kernel_stride_1", [&] {
508 indexing_backward_kernel_stride_1<scalar_t><<<grid, block, 0, stream>>>(
509 sorted_indices.const_data_ptr<int64_t>(),
510 orig_indices.const_data_ptr<int64_t>(),
511 expandedValue.const_data_ptr<scalar_t>(),
512 src_.mutable_data_ptr<scalar_t>(),
513 num_indices,
514 sliceSize,
515 strideBefore,
516 nElemBefore,
517 accumulate);
518 C10_CUDA_KERNEL_LAUNCH_CHECK();
519 });
520 } else {
521 if (sliceSize <= warp_size) {
522 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBool, kBFloat16,
523 expandedValue.scalar_type(), "indexing_backward_kernel_small_stride", [&] {
524 indexing_backward_kernel_small_stride<scalar_t><<<grid, block, 0, stream>>>(
525 sorted_indices.const_data_ptr<int64_t>(),
526 orig_indices.const_data_ptr<int64_t>(),
527 expandedValue.const_data_ptr<scalar_t>(),
528 src_.mutable_data_ptr<scalar_t>(),
529 num_indices,
530 sliceSize,
531 strideBefore,
532 nElemBefore,
533 accumulate);
534 C10_CUDA_KERNEL_LAUNCH_CHECK();
535 });
536 } else {
537 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(kComplexHalf, kHalf, kBool, kBFloat16,
538 expandedValue.scalar_type(), "indexing_backward", [&] {
539 indexing_backward_kernel<scalar_t, UNROLL><<<grid, block, 0, stream>>>(
540 sorted_indices.const_data_ptr<int64_t>(),
541 orig_indices.const_data_ptr<int64_t>(),
542 expandedValue.const_data_ptr<scalar_t>(),
543 src_.mutable_data_ptr<scalar_t>(),
544 num_indices,
545 sliceSize,
546 strideBefore,
547 nElemBefore,
548 accumulate);
549 C10_CUDA_KERNEL_LAUNCH_CHECK();
550 });
551 }
552 }
553
554 if (permuted) {
555 self.copy_(src_.permute(inversePerm));
556 } else if (!self_contiguous) {
557 self.copy_(self_);
558 }
559 }
560 }
561
562 REGISTER_CUDA_DISPATCH(index_put_with_sort_stub, &index_put_with_sort_kernel);
563
564 void index_put_with_sort_quantized(Tensor & self, const c10::List<std::optional<Tensor>>& indices, const Tensor & value, double scale, int zero_point, bool unsafe) {
565 if (indices.size() > (size_t)self.dim()) {
566 TORCH_CHECK_INDEX(false, "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
567 }
568 bool self_contiguous = self.is_contiguous();
569 auto self_ = self_contiguous ? self : self.contiguous();
570 Tensor linearIndex, src, expandedValue = value;
571 int64_t nElemBefore, strideBefore, sliceSize;
572 std::vector<int64_t> inversePerm;
573 std::tie(linearIndex, src, nElemBefore, strideBefore, sliceSize, inversePerm) = makeLinearIndex(self_, indices, !unsafe);
574 int64_t num_indices = linearIndex.numel();
575
576 if (expandedValue.numel() < num_indices * nElemBefore * sliceSize) {
577 auto expanded_size = at::DimVector(expandedValue.sizes());
578 auto size1 = expandedValue.sizes();
579 auto size2 = linearIndex.sizes();
580 if (are_expandable(size1, size2)) {
581 expanded_size = infer_size_dimvector(size1, size2);
582 }
583 if (nElemBefore > 1) {
584 expanded_size.insert(expanded_size.begin(), nElemBefore);
585 }
586 expandedValue = expandedValue.expand(expanded_size);
587 }
588 expandedValue = expandedValue.contiguous();
589
590 if (num_indices > 0 && sliceSize > 0) {
591 const bool permuted = !src.is_contiguous();
592 auto src_ = permuted ? src.contiguous() : src;
593 linearIndex = linearIndex.reshape(-1);
594 auto sorted_indices = at::empty_like(linearIndex, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
595 auto orig_indices = at::empty_like(linearIndex, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
596 const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
597
598 linearIndex.divide_(sliceSize, "trunc");
599
600       // cub on CUDA <= 11.2 has a bug where, for small sizes,
601       // cub's sort can be much slower than thrust's merge sort.
602       // This bug is fixed in CUDA 11.3.
603 #if (defined(CUDA_VERSION) && CUDA_VERSION < 11030) || defined(USE_ROCM)
604 if (num_indices < 50000) {
605 index_put_with_sort_kernel_thrust_helper(linearIndex, orig_indices, sorted_indices, num_indices);
606 } else
607 #endif
608 {
609 // Sort the inputs into sorted with the corresponding indices
610 auto range = at::arange(num_indices, linearIndex.options());
611       // linearIndex cannot be negative, and we take advantage of this
612       // fact to sort on fewer bits for better performance.
613 int64_t nbits = cuda::cub::get_num_bits(largestIndex(self_) / sliceSize);
614 cuda::cub::radix_sort_pairs(
615 linearIndex.const_data_ptr<int64_t>(), sorted_indices.mutable_data_ptr<int64_t>(),
616 range.const_data_ptr<int64_t>(), orig_indices.mutable_data_ptr<int64_t>(),
617 num_indices, false, 0, nbits);
618 }
619
620 TORCH_INTERNAL_ASSERT(
621 linearIndex.numel()*sliceSize*nElemBefore == expandedValue.numel(),
622 "number of flattened indices did not match number of elements in the value tensor: ",
623 linearIndex.numel()*sliceSize*nElemBefore, " vs ", expandedValue.numel());
624 const int UNROLL = 4;
625 const int indices_per_block = 4;
626 const int warp_size = at::cuda::warp_size();
627 dim3 grid(ceil_div(num_indices, (int64_t) indices_per_block),
628 std::min<int>(at::cuda::getCurrentDeviceProperties()->maxGridSize[1], ceil_div(sliceSize, (int64_t) (warp_size*UNROLL))),
629 std::min(std::max<int>(1,nElemBefore), at::cuda::getCurrentDeviceProperties()->maxGridSize[2]));
630 dim3 block(warp_size, indices_per_block);
631
632 AT_DISPATCH_QINT_TYPES(
633 src.scalar_type(), "indexing_backward_quantized", [&] {
634 constexpr int64_t qmin = std::numeric_limits<typename scalar_t::underlying>::min();
635 constexpr int64_t qmax = std::numeric_limits<typename scalar_t::underlying>::max();
636 float inv_scale = 1.0f / static_cast<float>(scale);
637
638 indexing_backward_kernel_quantized<scalar_t, UNROLL><<<grid, block, 0, stream>>>(
639 sorted_indices.const_data_ptr<int64_t>(),
640 orig_indices.const_data_ptr<int64_t>(),
641 expandedValue.const_data_ptr<float>(),
642 src_.mutable_data_ptr<scalar_t>(),
643 num_indices,
644 sliceSize,
645 strideBefore,
646 nElemBefore,
647 inv_scale,
648 zero_point,
649 qmin,
650 qmax);
651 C10_CUDA_KERNEL_LAUNCH_CHECK();
652 });
653
654 if (permuted) {
655 self.copy_(src_.permute(inversePerm));
656 } else if (!self_contiguous) {
657 self.copy_(self_);
658 }
659 }
660 }
661
662 REGISTER_CUDA_DISPATCH(index_put_with_sort_quantized_stub, &index_put_with_sort_quantized);
663 } //anonymous
664
665
666 // Check tensor dimensions for index operations, and return the slice size.
667 static size_t getSliceSize(const Tensor & dst,
668 int dim,
669 const Tensor & index,
670 const Tensor & src)
671 {
672 const auto dstDims = dst.dim();
673 const auto srcDims = src.dim();
674
675 TORCH_CHECK(index.dim() <= 1, "Index must be vector or scalar");
676
677 size_t dstSliceSize = 1;
678 TORCH_CHECK(dim >= 0 && dim < dstDims, "Indexing dim ", dim, " is out of bounds");
679 for (const auto d: c10::irange(dstDims)) {
680 if (d != dim) {
681 dstSliceSize *= dst.size(d);
682 }
683 }
684
685 TORCH_CHECK(dim < srcDims, "Indexing dim ", dim, " is out of bounds");
686   TORCH_CHECK(index.numel() == src.size(dim),
687               "length of index must equal src.size(dim)");
688
689 size_t srcSliceSize = 1;
690 bool mismatch = false;
691
692 if (dstDims != srcDims) mismatch = true;
693
694 for (const auto d: c10::irange(srcDims)) {
695 if (d != dim) {
696 srcSliceSize *= src.size(d);
697 if (!mismatch && dst.size(d) != src.size(d)) mismatch = true;
698 }
699 }
700
701   TORCH_CHECK(dstSliceSize == srcSliceSize,
702               "Source/destination tensors have different slice sizes (",
703               dstSliceSize, " vs ", srcSliceSize, ")");
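  // Example (added for clarity): for dst of shape (5, 4, 3), dim == 1, an index of
  // length 7 and src of shape (5, 7, 3), both slice sizes are 5 * 3 == 15.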
704
705 if (mismatch) {
706     TORCH_WARN_ONCE(
707         "source/destination slices have the same size but different "
708         "shape for an index operation. This behavior is deprecated.");
709 }
710
711 return dstSliceSize;
712 }
713
714 // We prefer this kernel to avoid reloading index points if the number
715 // of indices is a small number.
716 // This kernel in fact works for all choices of problem size, but if
717 // the number of indices chosen is large, then the
718 // indexFuncLargeIndex kernel is a better choice to increase
719 // parallelism.
720 template <typename T, typename IndicesType, typename IndexType, int DstDim, int SrcDim, int IdxDim,
721 typename func_t>
722 __global__ void indexFuncSmallIndex(cuda::detail::TensorInfo<T, IndexType> dst,
723 cuda::detail::TensorInfo<const T, IndexType> src,
724 cuda::detail::TensorInfo<const IndicesType, IndexType> indices,
725 int dstAddDim,
726 int srcAddDim,
727 IndexType innerSize,
728 int64_t dstAddDimSize,
729 int64_t dstNumel,
730 const func_t& op,
731 T alpha) {
732 // In order to avoid reloading the index that we are copying, load
733 // it once to handle all of the points that are being selected, so
734 // it can be reused as much as possible. This kernel is chosen when
735 // this is a good choice (small number of chosen indices), since
736 // re-accessing indices in addition to src elements can be slow.
737 for (IndexType srcIndex = 0; srcIndex < indices.sizes[0]; ++srcIndex) {
738     // Look up the destination index for this source slice
739 IndexType dstIndex =
740 indices.data[cuda::detail::IndexToOffset<const IndicesType, IndexType, IdxDim>::get(srcIndex, indices)];
741 CUDA_KERNEL_ASSERT(dstIndex < dstAddDimSize);
742
743 // We stride over the output ignoring the indexed dimension
744 // (innerSize), whose offset calculation is handled differently
745 for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
746 linearIndex < innerSize;
747 linearIndex += gridDim.x * blockDim.x) {
748 IndexType dstOffset =
749 cuda::detail::IndexToOffset<T, IndexType, DstDim>::get(linearIndex, dst);
750 dstOffset += dstIndex * dst.strides[dstAddDim];
751
752 IndexType srcOffset =
753 cuda::detail::IndexToOffset<const T, IndexType, SrcDim>::get(linearIndex, src);
754 srcOffset += srcIndex * src.strides[srcAddDim];
755
756 T val = src.data[srcOffset] * alpha;
757 op(dst.data, dstOffset, dstNumel, &val);
758 }
759
760 }
761 }
762
763 // We prefer this kernel to balance parallelism across index points,
764 // if there are a large number of indices.
765 // This kernel in fact works for all choices of problem size, but if
766 // the number of indices chosen is small, then the
767 // indexFuncSmallIndex kernel is a better choice to reduce memory
768 // accesses.
769 template <typename T, typename IndicesType, typename IndexType, int DstDim, int SrcDim, int IdxDim,
770 bool IndexIsMajor, typename func_t>
771 __global__ void indexFuncLargeIndex(cuda::detail::TensorInfo<T, IndexType> dst,
772 cuda::detail::TensorInfo<const T, IndexType> src,
773 cuda::detail::TensorInfo<const IndicesType, IndexType> indices,
774 int dstAddDim,
775 int srcAddDim,
776 IndexType totalSize,
777 IndexType innerSize,
778 int64_t dstAddDimSize,
779 int64_t dstNumel,
780 const func_t& op,
781 T alpha) {
782 // We stride over the output including the indexed dimension
783 // (totalSize), and calculate the destination index point based on that
784 for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
785 linearIndex < totalSize;
786 linearIndex += gridDim.x * blockDim.x) {
787 IndexType srcIndex, elementInSlice;
788 if (IndexIsMajor) {
789 srcIndex = linearIndex / innerSize;
790 elementInSlice = linearIndex % innerSize;
791 }
792 else {
793 elementInSlice = linearIndex / innerSize;
794 srcIndex = linearIndex % innerSize;
795 }
796
797     // Look up the destination index for this element
798 IndexType dstIndex =
799 indices.data[cuda::detail::IndexToOffset<const IndicesType, IndexType, IdxDim>::get(srcIndex, indices)];
800 CUDA_KERNEL_ASSERT(dstIndex < dstAddDimSize);
801
802 IndexType dstOffset =
803 cuda::detail::IndexToOffset<T, IndexType, DstDim>::get(elementInSlice, dst);
804 dstOffset += dstIndex * dst.strides[dstAddDim];
805
806 IndexType srcOffset =
807 cuda::detail::IndexToOffset<const T, IndexType, SrcDim>::get(elementInSlice, src);
808 srcOffset += srcIndex * src.strides[srcAddDim];
809
810 T val = src.data[srcOffset] * alpha;
811 op(dst.data, dstOffset, dstNumel, &val);
812 }
813 }
814
815 // Compare the stride between adjacent slices (sliceStride) with strides in the
816 // other dimensions (i.e., strides *inside* each slice).
817 //
818 // - Returns true if some dimension inside the slice has lower stride than
819 // sliceStride. The simplest example is a 2-D contiguous tensor with sliceDim
820 // == 0 (that is, each slice is a row).
821 //
822 // In this case, we choose the CUDA kernel that processes the data in
823 // "index-major order". For example, if thread count equals slice size, then
824 // all threads process slice #0 in lockstep, and then slice #1, and so on.
825 //
826 // - Otherwise (i.e., sliceStride has the lowest value), this function returns
827 // false. The simplest example is a 2-D contiguous tensor with sliceDim == 1
828 // (each slice is a column).
829 //
830 // In this case, we choose the CUDA kernel that processes the data in
831 // "elementInSlice-major order". For example, each thread can process element
832 // #0 of every slice, and then element #1 of every slice, and so on.
833 template <typename scalar_t>
834 bool indexShouldBeMajor(cuda::detail::TensorInfo<scalar_t, unsigned int> &info,
835 int sliceDim)
836 {
837 // The stride between adjacent slices (e.g., between element #0 of slice #100
838 // and element #0 of slice #101).
839 unsigned int sliceStride = info.strides[sliceDim];
840
841 for (const auto i: c10::irange(info.dims)) {
842 if (i != sliceDim && info.sizes[i] > 1 && info.strides[i] < sliceStride) {
843 return true;
844 }
845 }
846
847 return false;
848 }
849
850 void index_add_cuda_impl(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& source, const Scalar& alpha, const Tensor& result) {
851 if (!result.is_same(self)) {
852 result.copy_(self);
853 }
854
855   // Scalars are treated as 1-d tensors
856 const Tensor self_ = (result.dim() == 0) ? result.view(1) : result;
857 const Tensor source_ = (source.dim() == 0) ? source.view(1) : source;
858
859 TORCH_CHECK(result.dim() <= MAX_TENSORINFO_DIMS, "tensor has too many (>", MAX_TENSORINFO_DIMS, ") dims");
860 TORCH_CHECK(source.dim() <= MAX_TENSORINFO_DIMS, "tensor has too many (>", MAX_TENSORINFO_DIMS, ") dims" );
861 TORCH_CHECK(index.dim() <= MAX_TENSORINFO_DIMS, "tensor has too many (>", MAX_TENSORINFO_DIMS, ") dims");
862
863 if (globalContext().deterministicAlgorithms()){
864 torch::List<std::optional<Tensor>> indices;
865 indices.reserve(dim + 1);
866 for (const auto i: c10::irange(dim)) {
867 indices.emplace_back();
868 }
869 indices.emplace_back(index.to(at::kLong));
870 result.index_put_(indices, source * alpha, true);
871 return;
872 }
873
874 // The `source` is partitioned into two parts:
875 // -the size of each slice we are indexing, which is the
876 // total size of the tensor ignoring dimension `dim`;
877   // -the number of indices we are choosing, which is the total size
878 // of the tensor `index`.
879 const uint64_t sliceSize = getSliceSize(self_, dim, index, source_);
880 const uint64_t sourceTotalSize = source.numel();
881 const uint64_t selfAddDimSize = self_.size(dim);
882 const uint64_t numIndex = index.numel();
883 const uint64_t selfNumel = self_.numel();
884
885 if (sliceSize == 0) {
886 return;
887 }
888 const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
889 const bool indContig = index.is_contiguous();
890
891 const int mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
892
893 #define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM) \
894 indexFuncSmallIndex<TENSOR_TYPE, INDICES_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM> \
895 <<<smallIndexGrid, smallIndexBlock, 0, stream>>>( \
896 selfInfo, sourceInfo, indexInfo, \
897 selfAddDim, sourceAddDim, sliceSize, selfAddDimSize, \
898 selfNumel, reduce_add, alpha_value); \
899 C10_CUDA_KERNEL_LAUNCH_CHECK();
900
901 #define LARGE_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, \
902 SELF_DIM, SOURCE_DIM, IDX_DIM, IDX_IS_MAJOR) \
903 indexFuncLargeIndex<TENSOR_TYPE, INDICES_TYPE, TYPE, \
904 SELF_DIM, SOURCE_DIM, IDX_DIM, IDX_IS_MAJOR> \
905 <<<largeIndexGrid, largeIndexBlock, 0, stream>>>( \
906 selfInfo, sourceInfo, indexInfo, \
907 selfAddDim, sourceAddDim, sourceTotalSize, \
908 (IDX_IS_MAJOR) ? sliceSize : numIndex, \
909 selfAddDimSize, selfNumel, reduce_add, alpha_value); \
910 C10_CUDA_KERNEL_LAUNCH_CHECK();
911
912 const dim3 smallIndexGrid(std::min(ceil_div(sliceSize, (uint64_t)128), (uint64_t)(mpc * 8)));
913 const dim3 smallIndexBlock(std::min(sliceSize, (uint64_t)128));
914
915 const dim3 largeIndexGrid(std::min(ceil_div(sourceTotalSize, (uint64_t)128), (uint64_t)(mpc * 8)));
916 const dim3 largeIndexBlock(std::min(sourceTotalSize, (uint64_t)128));
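  // Added note: the launch heuristic above uses 128 threads per block and caps the
  // grid at roughly 8 blocks per SM (mpc * 8) so small problems do not launch more
  // blocks than the device can usefully run.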
917
918 if (cuda::detail::canUse32BitIndexMath(result) &&
919 cuda::detail::canUse32BitIndexMath(source) &&
920 cuda::detail::canUse32BitIndexMath(index)) {
921 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, at::ScalarType::ComplexHalf, result.scalar_type(), "index_add", [&] {
922 cuda::detail::TensorInfo<scalar_t, unsigned int> selfInfo =
923 cuda::detail::getTensorInfo<scalar_t, unsigned int>(self_);
924 const int selfAddDim = selfInfo.collapseDims(dim);
925 selfInfo.reduceDim(selfAddDim);
926 const auto alpha_value = alpha.to<scalar_t>();
927 AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cuda_", [&] () {
928 auto sourceInfo =
929 cuda::detail::getTensorInfo<const scalar_t, unsigned int>(source_);
930 const int sourceAddDim = sourceInfo.collapseDims(dim);
931 sourceInfo.reduceDim(sourceAddDim);
932
933 auto indexInfo =
934 cuda::detail::getTensorInfo<const index_t, unsigned int>(index);
935 indexInfo.collapseDims();
936
937           // A reasonable heuristic for when to have each thread iterate over
938           // the chosen indices (i.e., when to use the small-index kernel)
939 if (numIndex <= 16) {
940 if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) {
941 SMALL_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2);
942 } else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) {
943 SMALL_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2);
944 } else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) {
945 SMALL_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2);
946 } else {
947 SMALL_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1);
948 }
949 } else {
950 const bool indexIsMajor = indexShouldBeMajor(selfInfo, selfAddDim);
951
952 if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) {
953 LARGE_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2, true);
954 } else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) {
955 if (indexIsMajor) {
956 LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, true);
957 } else {
958 LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, false);
959 }
960 } else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) {
961 if (indexIsMajor) {
962 LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, true);
963 } else {
964 LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, false);
965 }
966 } else {
967 LARGE_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1, true);
968 }
969 }
970 });
971 });
972 } else {
973 AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Bool, at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "index_add", [&] {
974 cuda::detail::TensorInfo<scalar_t, uint64_t> selfInfo =
975 cuda::detail::getTensorInfo<scalar_t, uint64_t>(self_);
976 const int selfAddDim = selfInfo.collapseDims(dim);
977 selfInfo.reduceDim(selfAddDim);
978 const auto alpha_value = alpha.to<scalar_t>();
979
980 cuda::detail::TensorInfo<const scalar_t, uint64_t> sourceInfo =
981 cuda::detail::getTensorInfo<const scalar_t, uint64_t>(source_);
982 const int sourceAddDim = sourceInfo.collapseDims(dim);
983 sourceInfo.reduceDim(sourceAddDim);
984
985 AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_add_cuda_", [&] () {
986 cuda::detail::TensorInfo<const index_t, uint64_t> indexInfo =
987 cuda::detail::getTensorInfo<const index_t, uint64_t>(index);
988 indexInfo.collapseDims();
989
990 LARGE_INDEX(scalar_t, index_t, uint64_t, -1, -1, -1, true);
991 });
992 });
993 }
994
995 #undef SMALL_INDEX
996 #undef LARGE_INDEX
997 }
998
999 template <typename func_t>
1000 void index_reduce_func_cuda_impl(
1001 const Tensor& self,
1002 int64_t dim,
1003 const Tensor& index,
1004 const Tensor& source,
1005 bool include_self,
1006 const ReductionType& reduce,
1007 const func_t& reduce_func,
1008 const Tensor& result) {
1009 globalContext().alertNotDeterministic("index_reduce_cuda");
1010
1011 if (!result.is_same(self)) result.copy_(self);
1012
1013   // Scalars are treated as 1-d tensors
1014 Tensor self_ = (result.dim() == 0) ? result.view(1) : result;
1015 Tensor source_ = (source.dim() == 0) ? source.view(1) : source;
1016
1017 TORCH_CHECK(result.dim() <= MAX_TENSORINFO_DIMS, "tensor has too many (>", MAX_TENSORINFO_DIMS, ") dims");
1018 TORCH_CHECK(source.dim() <= MAX_TENSORINFO_DIMS, "tensor has too many (>", MAX_TENSORINFO_DIMS, ") dims" );
1019 TORCH_CHECK(index.dim() <= MAX_TENSORINFO_DIMS, "tensor has too many (>", MAX_TENSORINFO_DIMS, ") dims");
1020
1021 if (!include_self) {
1022 AT_DISPATCH_ALL_TYPES_AND2(
1023 at::ScalarType::Half, at::ScalarType::BFloat16,
1024 self.scalar_type(), "index_reduce_func_cuda_exclude_input_init", [&] {
1025 scalar_t init_val;
1026 switch (reduce) {
1027 case ReductionType::PROD:
1028 init_val = (scalar_t)1;
1029 break;
1030 case ReductionType::MAX:
1031 init_val = std::numeric_limits<scalar_t>::has_infinity ? -std::numeric_limits<scalar_t>::infinity()
1032 : std::numeric_limits<scalar_t>::lowest();
1033 break;
1034 case ReductionType::MIN:
1035 init_val = std::numeric_limits<scalar_t>::has_infinity ? std::numeric_limits<scalar_t>::infinity()
1036 : std::numeric_limits<scalar_t>::max();
1037 break;
1038 default:
1039 init_val = (scalar_t)0;
1040 break;
1041 }
1042 // index_fill_ requires index to be a LongTensor
1043 self_.index_fill_(dim, index.to(at::ScalarType::Long), init_val);
1044 });
1045 }
1046
1047 // The `source` is partitioned into two parts:
1048 // -the size of each slice we are indexing, which is the
1049 // total size of the tensor ignoring dimension `dim`;
1050   // -the number of indices we are choosing, which is the total size
1051 // of the tensor `index`.
1052 uint64_t sliceSize = getSliceSize(self_, dim, index, source_);
1053 uint64_t sourceTotalSize = source.numel();
1054 uint64_t selfReduceDimSize = self_.size(dim);
1055 uint64_t numIndex = index.numel();
1056 uint64_t selfNumel = self_.numel();
1057
1058 if (sliceSize == 0) {
1059 return;
1060 }
1061 const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
1062 bool indContig = index.is_contiguous();
1063
1064 int mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
1065
1066 #define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM) \
1067 indexFuncSmallIndex<TENSOR_TYPE, INDICES_TYPE, TYPE, SELF_DIM, SOURCE_DIM, IDX_DIM> \
1068 <<<smallIndexGrid, smallIndexBlock, 0, stream>>>( \
1069 selfInfo, sourceInfo, indexInfo, \
1070 selfReduceDim, sourceReduceDim, sliceSize, selfReduceDimSize, \
1071 selfNumel, reduce_func, alpha_value); \
1072 C10_CUDA_KERNEL_LAUNCH_CHECK();
1073
1074 #define LARGE_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, \
1075 SELF_DIM, SOURCE_DIM, IDX_DIM, IDX_IS_MAJOR) \
1076 indexFuncLargeIndex<TENSOR_TYPE, INDICES_TYPE, TYPE, \
1077 SELF_DIM, SOURCE_DIM, IDX_DIM, IDX_IS_MAJOR> \
1078 <<<largeIndexGrid, largeIndexBlock, 0, stream>>>( \
1079 selfInfo, sourceInfo, indexInfo, \
1080 selfReduceDim, sourceReduceDim, sourceTotalSize, \
1081 (IDX_IS_MAJOR) ? sliceSize : numIndex, \
1082 selfReduceDimSize, selfNumel, reduce_func, alpha_value); \
1083 C10_CUDA_KERNEL_LAUNCH_CHECK();
1084
1085 dim3 smallIndexGrid(std::min(ceil_div(sliceSize, (uint64_t)128), (uint64_t)(mpc * 8)));
1086 dim3 smallIndexBlock(std::min(sliceSize, (uint64_t)128));
1087
1088 dim3 largeIndexGrid(std::min(ceil_div(sourceTotalSize, (uint64_t)128), (uint64_t)(mpc * 8)));
1089 dim3 largeIndexBlock(std::min(sourceTotalSize, (uint64_t)128));
1090
1091 if (cuda::detail::canUse32BitIndexMath(result) &&
1092 cuda::detail::canUse32BitIndexMath(source) &&
1093 cuda::detail::canUse32BitIndexMath(index)) {
1094 AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, result.scalar_type(), "index_reduce", [&] {
1095 cuda::detail::TensorInfo<scalar_t, unsigned int> selfInfo =
1096 cuda::detail::getTensorInfo<scalar_t, unsigned int>(self_);
1097 int selfReduceDim = selfInfo.collapseDims(dim);
1098 selfInfo.reduceDim(selfReduceDim);
1099 auto alpha_value = (scalar_t) 1;
1100 AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_reduce_cuda", [&] () {
1101 auto sourceInfo =
1102 cuda::detail::getTensorInfo<const scalar_t, unsigned int>(source_);
1103 int sourceReduceDim = sourceInfo.collapseDims(dim);
1104 sourceInfo.reduceDim(sourceReduceDim);
1105
1106 auto indexInfo =
1107 cuda::detail::getTensorInfo<const index_t, unsigned int>(index);
1108 indexInfo.collapseDims();
1109
1110       // A reasonable heuristic for when to have each thread iterate over
1111       // the chosen indices (i.e., when to use the small-index kernel)
1112 if (numIndex <= 16) {
1113 if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) {
1114 SMALL_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2);
1115 } else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) {
1116 SMALL_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2);
1117 } else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) {
1118 SMALL_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2);
1119 } else {
1120 SMALL_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1);
1121 }
1122 } else {
1123 bool indexIsMajor = indexShouldBeMajor(selfInfo, selfReduceDim);
1124
1125 if (selfInfo.dims == 1 && sourceInfo.dims == 1 && indContig) {
1126 LARGE_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2, true);
1127 } else if (selfInfo.dims == 2 && sourceInfo.dims == 2 && indContig) {
1128 if (indexIsMajor) {
1129 LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, true);
1130 } else {
1131 LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, false);
1132 }
1133 } else if (selfInfo.dims == 3 && sourceInfo.dims == 3 && indContig) {
1134 if (indexIsMajor) {
1135 LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, true);
1136 } else {
1137 LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, false);
1138 }
1139 } else {
1140 LARGE_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1, true);
1141 }
1142 }
1143 });
1144 });
1145 } else {
1146 AT_DISPATCH_ALL_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "index_reduce", [&] {
1147 cuda::detail::TensorInfo<scalar_t, uint64_t> selfInfo =
1148 cuda::detail::getTensorInfo<scalar_t, uint64_t>(self_);
1149 int selfReduceDim = selfInfo.collapseDims(dim);
1150 selfInfo.reduceDim(selfReduceDim);
1151 auto alpha_value = (scalar_t) 1;
1152
1153 cuda::detail::TensorInfo<const scalar_t, uint64_t> sourceInfo =
1154 cuda::detail::getTensorInfo<const scalar_t, uint64_t>(source_);
1155 int sourceReduceDim = sourceInfo.collapseDims(dim);
1156 sourceInfo.reduceDim(sourceReduceDim);
1157
1158 AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_reduce_cuda", [&] () {
1159 cuda::detail::TensorInfo<const index_t, uint64_t> indexInfo =
1160 cuda::detail::getTensorInfo<const index_t, uint64_t>(index);
1161 indexInfo.collapseDims();
1162
1163 LARGE_INDEX(scalar_t, index_t, uint64_t, -1, -1, -1, true);
1164 });
1165 });
1166 }
1167
1168 #undef SMALL_INDEX
1169 #undef LARGE_INDEX
1170 }
1171
1172 TORCH_IMPL_FUNC(index_add_cuda_out)
1173 (const Tensor& self, int64_t dim, const Tensor& index, const Tensor& source, const Scalar& alpha, const Tensor& result) {
1174 index_add_cuda_impl(self, dim, index, source, alpha, result);
1175 }
1176
1177 TORCH_IMPL_FUNC(index_reduce_cuda_out)
1178 (const Tensor& self,
1179 int64_t dim,
1180 const Tensor& index,
1181 const Tensor& source,
1182 const c10::string_view reduce,
1183 bool include_self,
1184 const Tensor& result) {
1185 TORCH_WARN_ONCE("index_reduce() is in beta and the API may change at any time.");
1186
1187 if (reduce == "prod") {
1188 index_reduce_func_cuda_impl(self, dim, index, source, include_self, ReductionType::PROD, reduce_multiply, result);
1189 } else if (reduce == "mean") {
1190 index_reduce_func_cuda_impl(self, dim, index, source, include_self, ReductionType::MEAN, reduce_add, result);
1191 auto counts = include_self ? at::ones_like(result) : at::zeros_like(result);
1192 counts.index_add_(dim, index, at::ones_like(source));
1193 counts.masked_fill_(counts == 0, 1);
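    // Added note: the "mean" reduction is implemented as a sum (reduce_add above)
    // followed by an element-wise division by the per-row counts; counts of zero
    // (rows never touched when include_self == false) are clamped to 1 so the
    // division below is a no-op for them.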
1194 if (result.is_floating_point() || result.is_complex()) {
1195 result.div_(counts);
1196 } else {
1197 result.div_(counts, "floor");
1198 }
1199 } else if (reduce == "amax") {
1200 index_reduce_func_cuda_impl(self, dim, index, source, include_self, ReductionType::MAX, reduce_maximum, result);
1201 } else if (reduce == "amin") {
1202 index_reduce_func_cuda_impl(self, dim, index, source, include_self, ReductionType::MIN, reduce_minimum, result);
1203 } else {
1204 TORCH_CHECK(false, "reduce argument must be either prod, mean, amax or amin, got ", reduce, ".");
1205 }
1206 }
1207
1208 namespace {
1209 // We prefer this kernel to avoid reloading index points if the number
1210 // of indices is a small number.
1211 // This kernel in fact works for all choices of problem size, but if
1212 // the number of indices chosen is large, then the
1213 // indexSelectLargeIndex kernel is a better choice to increase
1214 // parallelism.
1215 template <typename T, typename IndicesType, typename IndexType, int DstDim, int SrcDim, int IdxDim>
1216 __global__ void indexSelectSmallIndex(cuda::detail::TensorInfo<T, IndexType> dst,
1217 cuda::detail::TensorInfo<const T, IndexType> src,
1218 cuda::detail::TensorInfo<const IndicesType, IndexType> indices,
1219 int dstSelectDim,
1220 int srcSelectDim,
1221 IndexType innerSize,
1222 int64_t srcSelectDimSize) {
1223 // In order to avoid reloading the index that we are copying, load
1224 // it once to handle all of the points that are being selected, so
1225 // it can be reused as much as possible. This kernel is chosen when
1226 // this is a good choice (small number of chosen indices), since
1227 // re-accessing indices in addition to src elements can be slow.
1228 for (IndexType dstIndex = 0; dstIndex < indices.sizes[0]; ++dstIndex) {
1229 IndexType srcIndex =
1230 indices.data[cuda::detail::IndexToOffset<const IndicesType, IndexType, IdxDim>::get(dstIndex, indices)];
1231 CUDA_KERNEL_ASSERT(srcIndex < srcSelectDimSize);
1232
1233 // We stride over the output ignoring the indexed dimension
1234 // (innerSize), whose offset calculation is handled differently
1235 for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
1236 linearIndex < innerSize;
1237 linearIndex += gridDim.x * blockDim.x) {
1238 IndexType dstOffset =
1239 cuda::detail::IndexToOffset<T, IndexType, DstDim>::get(linearIndex, dst);
1240 dstOffset += dstIndex * dst.strides[dstSelectDim];
1241
1242 IndexType srcOffset =
1243 cuda::detail::IndexToOffset<const T, IndexType, SrcDim>::get(linearIndex, src);
1244 srcOffset += srcIndex * src.strides[srcSelectDim];
1245
1246 dst.data[dstOffset] = src.data[srcOffset];
1247 }
1248 }
1249 }
1250
1251 // We prefer this kernel to balance parallelism across index points,
1252 // if there are a large number of indices.
1253 // This kernel in fact works for all choices of problem size, but if
1254 // the number of indices chosen is small, then the
1255 // indexSelectSmallIndex kernel is a better choice to reduce memory
1256 // accesses.
1257 template <typename T, typename IndicesType, typename IndexType, int DstDim, int SrcDim, int IdxDim,
1258 bool IndexIsMajor>
1259 __global__ void indexSelectLargeIndex(cuda::detail::TensorInfo<T, IndexType> dst,
1260 cuda::detail::TensorInfo<const T, IndexType> src,
1261 cuda::detail::TensorInfo<const IndicesType, IndexType> indices,
1262 int dstSelectDim,
1263 int srcSelectDim,
1264 IndexType totalSize,
1265 IndexType innerSize,
1266 int64_t srcSelectDimSize) {
1267 // We stride over the output including the indexed dimension
1268 // (totalSize), and calculate the destination index point based on that
1269 for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x;
1270 linearIndex < totalSize;
1271 linearIndex += gridDim.x * blockDim.x) {
1272 IndexType dstIndex, elementInSlice;
1273 if (IndexIsMajor) {
1274 dstIndex = linearIndex / innerSize;
1275 elementInSlice = linearIndex % innerSize;
1276 }
1277 else {
1278 elementInSlice = linearIndex / innerSize;
1279 dstIndex = linearIndex % innerSize;
1280 }
1281
1282 IndexType srcIndex =
1283 indices.data[cuda::detail::IndexToOffset<const IndicesType, IndexType, IdxDim>::get(dstIndex, indices)];
1284 CUDA_KERNEL_ASSERT(srcIndex < srcSelectDimSize);
1285
1286 IndexType dstOffset =
1287 cuda::detail::IndexToOffset<T, IndexType, DstDim>::get(elementInSlice, dst);
1288 dstOffset += dstIndex * dst.strides[dstSelectDim];
1289
1290 IndexType srcOffset =
1291 cuda::detail::IndexToOffset<const T, IndexType, SrcDim>::get(elementInSlice, src);
1292 srcOffset += srcIndex * src.strides[srcSelectDim];
1293
1294 dst.data[dstOffset] = src.data[srcOffset];
1295 }
1296 }
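
// Illustrative example of the decomposition above (not part of the build):
// with IndexIsMajor == true and innerSize == 4, linearIndex == 10 maps to
// dstIndex == 10 / 4 == 2 and elementInSlice == 10 % 4 == 2, i.e. the third
// element of the slice that is written for the third selected index.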

namespace {

// When using a 0-dim scalar tensor, we need the legacy (THC) semantics of
// TensorInfo: Pretend that the scalar tensor is in fact a one-element vector.
template <typename T, typename IndexType>
cuda::detail::TensorInfo<T, IndexType>
tensorInfoLegacyIfScalar(cuda::detail::TensorInfo<T, IndexType> ti) {
  if (ti.dims == 0) {
    ti.dims = 1;
    ti.sizes[0] = 1;
    ti.strides[0] = 1;
  }
  return ti;
}
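
// Illustrative example (not part of the build): a 0-dim scalar tensor arrives
// here with dims == 0; after this helper it is treated as sizes == {1} and
// strides == {1}, so the indexing kernels can address its single element with
// the same offset math they use for 1-D tensors.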

} // anonymous namespace

template <typename scalar_t>
void index_select_out_cuda_impl(
    Tensor& out,
    const Tensor& self,
    long dim,
    const Tensor& index) {
  uint64_t numIndices = index.numel();
  uint64_t selfDims = self.dim() == 0 ? 1 : self.dim();

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  TORCH_CHECK(
      index.dim() <= 1, "Index is supposed to be an empty tensor or a vector");
  TORCH_CHECK(
      !(self.dim() == 0 && numIndices != 1), "index_select(): Index to scalar can have only 1 value, got ", numIndices, " value(s)");
  TORCH_CHECK(dim < selfDims, "Indexing dim is out of bounds");

  std::vector<int64_t> newSize = self.sizes().vec();
  if (self.dim() > 0) {
    newSize[dim] = numIndices;
  }

  if (self.is_quantized()){
    out = at::empty_quantized(newSize, out);
  } else {
    at::native::resize_output(out, newSize);
  }

  uint64_t outTotalSize = out.numel();
  if (outTotalSize == 0) {
    return;
  }

  bool indContig = index.is_contiguous();

  // The `self` is partitioned into two parts:
  // -the size of each slice we are indexing, which is the
  // total size of the tensor ignoring dimension `dim`;
  // -the number of indices we are choosing, which is the total size
  // of the tensor `indices`.
  uint64_t selfSelectDimSize = self.dim() == 0 ? 1 : self.size(dim);
  uint64_t sliceSize = outTotalSize / numIndices;
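  // Illustrative example (not part of the build): for self of shape {5, 3, 4},
  // dim == 1 and numIndices == 2, out has shape {5, 2, 4}, so
  // outTotalSize == 40 and sliceSize == 40 / 2 == 20, the number of elements
  // copied per selected index.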

  int mpc = at::cuda::getCurrentDeviceProperties()->multiProcessorCount;

#define SMALL_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM)     \
  indexSelectSmallIndex<TENSOR_TYPE, INDICES_TYPE, TYPE, DST_DIM, SRC_DIM, IDX_DIM> \
    <<<smallIndexGrid, smallIndexBlock, 0, stream>>>(                               \
      outInfo, selfInfo, indicesInfo,                                               \
      outSelectDim, selfSelectDim, static_cast<TYPE>(sliceSize),                    \
      selfSelectDimSize);                                                           \
  C10_CUDA_KERNEL_LAUNCH_CHECK();

#define LARGE_INDEX(TENSOR_TYPE, INDICES_TYPE, TYPE,                \
                    DST_DIM, SRC_DIM, IDX_DIM, IDX_IS_MAJOR)        \
  indexSelectLargeIndex<TENSOR_TYPE, INDICES_TYPE, TYPE,            \
                        DST_DIM, SRC_DIM, IDX_DIM, IDX_IS_MAJOR>    \
    <<<largeIndexGrid, largeIndexBlock, 0, stream>>>(               \
      outInfo, selfInfo, indicesInfo,                               \
      outSelectDim, selfSelectDim, static_cast<TYPE>(outTotalSize), \
      static_cast<TYPE>((IDX_IS_MAJOR) ? sliceSize : numIndices),   \
      selfSelectDimSize);                                           \
  C10_CUDA_KERNEL_LAUNCH_CHECK();

  dim3 smallIndexGrid(std::min(ceil_div(sliceSize, (uint64_t)128), (uint64_t) (mpc * 8)));
  dim3 smallIndexBlock(std::min(sliceSize, (uint64_t)128));

  dim3 largeIndexGrid(std::min(ceil_div(outTotalSize, (uint64_t)128), (uint64_t) (mpc * 8)));
  // For https://github.com/pytorch/pytorch/issues/130806 there were two problems here:
  // 1: ptrdiff_t was used, but it is signed, so an outTotalSize of 2147483648 could overflow it.
  // 2: On ROCm, std::min -> ::min did not work as expected when outTotalSize >= 2147483648.
  dim3 largeIndexBlock( (outTotalSize < 128) ? outTotalSize : 128 );
  if (cuda::detail::canUse32BitIndexMath(out) &&
      cuda::detail::canUse32BitIndexMath(self) &&
      cuda::detail::canUse32BitIndexMath(index)) {
    auto outInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo<scalar_t, unsigned int>(out));
    int outSelectDim = outInfo.collapseDims(dim);
    outInfo.reduceDim(outSelectDim);

    auto selfInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo<const scalar_t, unsigned int>(self));
    int selfSelectDim = selfInfo.collapseDims(dim);
    selfInfo.reduceDim(selfSelectDim);

    AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_select_out_cuda_impl", [&] () {
      auto indicesInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo<const index_t, unsigned int>(index));
      indicesInfo.collapseDims();

      // A reasonable choice for when to have each thread iterate over
      // indices to choose
      if (numIndices <= 16) {
        if (outInfo.dims == 1 && selfInfo.dims == 1 && indContig) {
          SMALL_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2);
        } else if (outInfo.dims == 2 && selfInfo.dims == 2 && indContig) {
          SMALL_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2);
        } else if (outInfo.dims == 3 && selfInfo.dims == 3 && indContig) {
          SMALL_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2);
        } else {
          SMALL_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1);
        }
      } else {
        bool indexIsMajor = indexShouldBeMajor(outInfo, outSelectDim);

        if (outInfo.dims == 1 && selfInfo.dims == 1 && indContig) {
          LARGE_INDEX(scalar_t, index_t, unsigned int, 1, 1, -2, true);
        } else if (outInfo.dims == 2 && selfInfo.dims == 2 && indContig) {
          if (indexIsMajor) {
            LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, true);
          } else {
            LARGE_INDEX(scalar_t, index_t, unsigned int, 2, 2, -2, false);
          }
        } else if (outInfo.dims == 3 && selfInfo.dims == 3 && indContig) {
          if (indexIsMajor) {
            LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, true);
          } else {
            LARGE_INDEX(scalar_t, index_t, unsigned int, 3, 3, -2, false);
          }
        } else {
          LARGE_INDEX(scalar_t, index_t, unsigned int, -1, -1, -1, true);
        }
      }
    });
  } else {
    auto outInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo<scalar_t, uint64_t>(out));
    int outSelectDim = outInfo.collapseDims(dim);
    outInfo.reduceDim(outSelectDim);

    auto selfInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo<const scalar_t, uint64_t>(self));
    int selfSelectDim = selfInfo.collapseDims(dim);
    selfInfo.reduceDim(selfSelectDim);
    AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_select_out_cuda_impl", [&] () {
      auto indicesInfo = tensorInfoLegacyIfScalar(cuda::detail::getTensorInfo<const index_t, uint64_t>(index));
      indicesInfo.collapseDims();

      LARGE_INDEX(scalar_t, index_t, uint64_t, -1, -1, -1, true);
    });
  }
#undef SMALL_INDEX
#undef LARGE_INDEX
}
} // anonymous namespace

Tensor& index_select_out_cuda(
    const Tensor& self,
    int64_t dim,
    const Tensor& index,
    Tensor& out) {
  static constexpr string_view DIM_WARNING =
      "Tensor too large or too many (> 25) dimensions";
  TORCH_CHECK(
      at::cuda::check_device({out, self, index}),
      "Input, output and indices must be on the current device");
  at::assert_no_internal_overlap(out);
  at::assert_no_overlap(out, self);
  at::assert_no_overlap(out, index);

  dim = at::maybe_wrap_dim(dim, self);
  TORCH_CHECK(self.dim() <= MAX_TENSORINFO_DIMS, DIM_WARNING);
  TORCH_CHECK(index.dim() <= MAX_TENSORINFO_DIMS, DIM_WARNING);
  if (self.is_quantized()){
    TORCH_CHECK(
        self.qscheme() == kPerTensorAffine,
        "Only per_tensor quantized tensors are supported by index_select.");
    AT_DISPATCH_QINT_TYPES(out.scalar_type(), "index_select_quant_cuda", [&] {
      index_select_out_cuda_impl<scalar_t>(out, self, dim, index);
    });
  } else {
    AT_DISPATCH_V2(
        out.scalar_type(),
        "index_select_cuda",
        AT_WRAP([&] { index_select_out_cuda_impl<scalar_t>(out, self, dim, index); }),
        AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES),
        kComplexHalf,
        kHalf,
        kBool,
        kBFloat16
    );
  }

  return out;
}
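
// Illustrative usage (not part of the build): at::index_select on dense CUDA
// tensors ends up in index_select_out_cuda above, e.g.
//   auto t   = at::rand({3, 4}, at::kCUDA);
//   auto idx = at::tensor({0, 2}, at::dtype(at::kLong).device(at::kCUDA));
//   auto out = at::index_select(t, /*dim=*/0, idx);  // shape {2, 4}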

Tensor index_select_cuda(const Tensor& self, int64_t dim, const Tensor& index) {
  Tensor out = at::empty({0}, self.options());
  at::native::index_select_out_cuda(self, dim, index, out);
  return out;
}

Tensor index_select_quantized_cuda(const Tensor& self, int64_t dim, const Tensor& index) {
  TORCH_CHECK(
      self.qscheme() == kPerTensorAffine,
      "Only per_tensor quantized tensors are supported by index_select.");
  Tensor out = at::empty_quantized({0}, self);
  at::native::index_select_out_cuda(self, dim, index, out);
  return out;
}
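
// Illustrative usage (not part of the build): index_select on a per-tensor
// affine quantized CUDA tensor routes through index_select_quantized_cuda, e.g.
//   auto q = at::quantize_per_tensor(
//       at::rand({4, 3}, at::kCUDA), /*scale=*/0.1, /*zero_point=*/10, at::kQUInt8);
//   auto idx = at::tensor({3, 0}, at::dtype(at::kLong).device(at::kCUDA));
//   auto sel = at::index_select(q, /*dim=*/0, idx);  // quantized, shape {2, 3}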

namespace {

void masked_fill_kernel(TensorIterator& iter, const Scalar& value) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
      kBool, kHalf, kBFloat16, kComplexHalf, iter.common_dtype(), "masked_fill_", [&]() {
        const auto value_ = value.to<scalar_t>();
        gpu_kernel(
            iter, [value_] GPU_LAMBDA(scalar_t self, bool mask) -> scalar_t {
              if (mask) {
                return value_;
              }
              return self;
            });
      });
}

template <typename scalar_t>
void cuda_masked_fill_kernel_quantized(TensorIterator& iter, scalar_t quantized_val) {
  gpu_kernel(
      iter, [quantized_val] GPU_LAMBDA(scalar_t self, bool mask) -> scalar_t {
        if (mask) {
          return quantized_val;
        }
        return self;
      });
}

void masked_fill_kernel_quantized(TensorIterator& iter, const Scalar& value, double scale, int zero_point) {
  TORCH_CHECK(iter.input_dtype(1) == at::ScalarType::Bool, "masked_fill only supports boolean masks, ",
      "but got dtype ", iter.input_dtype(1));
  AT_DISPATCH_QINT_TYPES(
      iter.common_dtype(), "masked_fill_", [&]() {
        float float_val = value.to<float>();
        const auto quantized_val = quantize_val<scalar_t>(scale, zero_point, float_val);

        cuda_masked_fill_kernel_quantized<scalar_t>(iter, quantized_val);
      });
}
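
// Illustrative example (not part of the build): with scale == 0.1 and
// zero_point == 10, a fill value of 1.5 is quantized to
// round(1.5 / 0.1) + 10 == 25 (clamped to the range of the quantized dtype),
// and that integer representation is what the kernel writes into masked slots.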

REGISTER_CUDA_DISPATCH(masked_fill_kernel_quantized_stub, &masked_fill_kernel_quantized);

} // anonymous namespace

Tensor & masked_fill__cuda(Tensor& self, const Tensor & mask, const Scalar& value) {
  TORCH_CHECK(self.device() == mask.device(), "expected self and mask to be on the same device, but got mask on ",
      mask.device(), " and self on ", self.device());
  TORCH_CHECK(mask.scalar_type() == kBool,
      "masked_fill only supports boolean masks, but got dtype ", mask.scalar_type());
  auto maybe_outnames = namedinference::broadcast_to_outnames(self, mask, "masked_fill_");
  if (at::has_internal_overlap(self) == MemOverlap::Yes) {
    TORCH_WARN(
      "Use of masked_fill_ on expanded tensors is deprecated. "
      "Please clone() the tensor before performing this operation. "
      "This also applies to advanced indexing e.g. tensor[mask] = scalar");
  }
  at::assert_no_partial_overlap(self, mask);

  c10::MaybeOwned<Tensor> b_mask = expand_inplace(self, mask, "masked_fill_");

  auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .add_output(self)
      .add_const_input(self)
      .add_const_input(*b_mask)
      .build();

  masked_fill_kernel(iter, value);
  namedinference::propagate_names_if_nonempty(self, maybe_outnames);
  return self;
}
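
// Illustrative usage (not part of the build): the mask is broadcast against
// self by expand_inplace above, so a 1-D mask can fill columns of a 2-D tensor:
//   auto x = at::zeros({2, 3}, at::kCUDA);
//   auto m = at::arange(3, at::kCUDA) % 2 == 0;  // {true, false, true}
//   x.masked_fill_(m, 7);  // columns 0 and 2 become 7, column 1 stays 0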

Tensor & masked_fill__cuda(Tensor& self, const Tensor & mask, const Tensor & value) {
  TORCH_CHECK(value.dim() == 0, "masked_fill_ only supports a 0-dimensional value tensor, but got tensor "
      "with ", value.dim(), " dimension(s).");
  // We hit this function if either of the input tensors lives on CUDA.
  // It is OK for `value` to be a CPU tensor, but `self` and `mask` must not be
  // CPU tensors. The check that `self` and `mask` are on the same device is
  // done in the Scalar overload of `masked_fill__cuda` called below.
  TORCH_CHECK(!self.device().is_cpu(), "masked_fill_: Expected inputs to be on same device");
  return masked_fill__cuda(self, mask, value.item());
}

namespace {

// ForwardIt: only legacy random access iterator is supported.
template<class ForwardIt, class T, bool is_lower = true>
static __host__ __device__ __forceinline__
ForwardIt find_bound(ForwardIt first, ForwardIt last, const T& value) {
    ForwardIt it;
    typename std::iterator_traits<ForwardIt>::difference_type count, step;
    // NOTE: std::distance(first, last) compiles but produces wrong results here,
    // so only legacy random access iterators are safe in this code.
    count = last - first;

    while (count > 0) {
      it = first;
      step = count / 2;
      // avoiding std::advance(it, step),
      // although it does work unlike std::distance
      it += step;
      if (is_lower ? *it < value : value >= *it) {
        first = ++it;
        count -= step + 1;
      }
      else {
        count = step;
      }
    }
    return first;
}
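
// Illustrative example (not part of the build): over the sorted array
// a = {1, 1, 5, 8} and value == 1,
//   find_bound<const int*, int, /*is_lower=*/true >(a, a + 4, 1) returns a + 0
//   find_bound<const int*, int, /*is_lower=*/false>(a, a + 4, 1) returns a + 2
// so the difference (2) counts how often the value occurs; the sparse
// index_select below uses exactly this to intersect `index` with the sorted
// indices along `dim`.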

} // anonymous namespace

Tensor index_select_sparse_cuda(const Tensor& self, int64_t dim, const Tensor& index) {
  const auto ndim = self.dim();
  TORCH_CHECK_INDEX(ndim, "index_select() cannot be applied to a 0-dim tensor.");
  TORCH_CHECK_INDEX(
      index.dim() == 1 && index.dtype() == at::kLong && index.options().layout() == at::kStrided,
      "index_select() argument index must be 1-D strided (non-sparse) long-tensor.");
  dim = maybe_wrap_dim(dim, ndim);
  const auto size = self.size(dim);
  const auto sparse_dim = self.sparse_dim();
  const auto dense_dim = self.dense_dim();
  const auto indices = self._indices();
  const auto values = self._values();
  const auto nnz = values.size(0);
  const auto index_len = index.size(0);
  auto res_sizes = self.sizes().vec();
  res_sizes[dim] = index_len;

  // If indexing into sparse dimensions
  if (dim < sparse_dim) {
    const auto make_output = [
      dim, sparse_dim, dense_dim, res_sizes, &self, &indices, &values
    ](
      const Tensor& selected_dim_indices,
      const Tensor& res_dim_indices
    ) -> Tensor {
      auto res_indices = indices.index_select(1, selected_dim_indices);
      res_indices[dim] = res_dim_indices;
      const auto res_values = values.index_select(0, selected_dim_indices);

      return at::_sparse_coo_tensor_with_dims_and_tensors(
          sparse_dim, dense_dim, res_sizes, res_indices, res_values, self.options());
    };

    // short-circuit if index is empty
    if (!index_len) {
      return make_output(index, index);
    }

    const auto nneg_index = [&index, size]() -> Tensor {
      auto nneg_index = at::empty_like(index, at::MemoryFormat::Contiguous);

      auto iter = TensorIteratorConfig()
          .add_output(nneg_index)
          .add_input(index)
          .build();

      AT_DISPATCH_INDEX_TYPES(index.scalar_type(), "index_select_sparse_cuda", [&]() {
        gpu_kernel(iter, [size] GPU_LAMBDA (index_t idx) -> index_t {
          CUDA_KERNEL_ASSERT(idx >= -size && idx < size
              && "index_select(): index out of bounds");
          return idx < 0 ? idx + size : idx;
        });
      });
      return nneg_index;
    }();
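    // Illustrative example (not part of the build): with size == 5, an index
    // value of -2 is remapped by the kernel above to -2 + 5 == 3, so nneg_index
    // holds only non-negative positions along `dim`.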

    const auto dim_indices = indices[dim].contiguous();
    const auto idx_nneg_index = at::arange(index_len, nneg_index.options());
    const auto idx_dim_indices = at::arange(nnz, dim_indices.options());

    Tensor sorted_dim_indices, argsort_dim_indices;
    std::tie(sorted_dim_indices, argsort_dim_indices) = [&]() -> std::tuple<Tensor, Tensor> {
      if (dim == 0 && self.is_coalesced()) {
        return std::make_tuple(dim_indices, idx_dim_indices);
      }
      else {
        return dim_indices.sort();
      }
    }();

    Tensor intrsc_counts_nneg_index;
    Tensor intrsc_first_match_nneg_index;
    std::tie(intrsc_counts_nneg_index, intrsc_first_match_nneg_index) = [&]() -> std::tuple<Tensor, Tensor> {
      auto intrsc_counts_nneg_index = at::zeros_like(nneg_index);
      auto intrsc_first_match_nneg_index = at::zeros_like(nneg_index);

      auto iter = TensorIteratorConfig()
          .add_output(intrsc_first_match_nneg_index)
          .add_input(nneg_index)
          .add_input(idx_nneg_index)
          .build();

      AT_DISPATCH_INDEX_TYPES(nneg_index.scalar_type(), "index_select_sparse_cuda", [&]() {
        index_t* ptr_intrsc_counts_nneg_index = intrsc_counts_nneg_index.mutable_data_ptr<index_t>();
        const index_t* ptr_sorted_dim_indices = sorted_dim_indices.const_data_ptr<index_t>();
        gpu_kernel(
            iter,
            [ptr_intrsc_counts_nneg_index, ptr_sorted_dim_indices, nnz] GPU_LAMBDA (
              index_t idx_val, index_t idx_idx
            ) -> index_t {
              auto* lb = find_bound<const index_t*, index_t, true>(
                ptr_sorted_dim_indices,
                ptr_sorted_dim_indices + nnz,
                idx_val
              );
              auto* ub = find_bound<const index_t*, index_t, false>(
                ptr_sorted_dim_indices,
                ptr_sorted_dim_indices + nnz,
                idx_val
              );
              const auto idx_count = ub - lb;
              ptr_intrsc_counts_nneg_index[idx_idx] = idx_count;

              return lb - ptr_sorted_dim_indices;
            }
        );
      });

      return std::make_tuple(intrsc_counts_nneg_index, intrsc_first_match_nneg_index);
    }();

    // Unavoidable sync since the shape of the result is not known in advance
    auto res_len = intrsc_counts_nneg_index.sum().item<int64_t>();
    // Short-circuit if empty intersection
    if (!res_len) {
      auto empty_idx = at::empty({0}, nneg_index.options());
      return make_output(empty_idx, empty_idx);
    }

    Tensor selected_dim_indices, res_dim_indices;
    std::tie(selected_dim_indices, res_dim_indices) = [&]() -> std::tuple<Tensor, Tensor> {
      auto res_dim_indices = at::empty({res_len}, nneg_index.options());
      auto selected_dim_indices = at::empty_like(res_dim_indices);
      auto selected_dim_indices_offsets = intrsc_counts_nneg_index.cumsum(0)
          .sub_(intrsc_counts_nneg_index);
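      // Illustrative example (not part of the build): if the per-index match
      // counts are {2, 0, 3}, then cumsum gives {2, 2, 5} and subtracting the
      // counts yields the exclusive prefix sums {0, 2, 2}, i.e. the write
      // offset of each index's run in the flattened result.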

      // We need an output because TensorIterator does not allow void lambdas.
      auto dummy_output = at::empty({1}, dim_indices.options()).expand(IntArrayRef({index_len}));
      auto iter = TensorIteratorConfig()
          .add_output(dummy_output)
          // All iterations map to a single element in dummy_output by design,
          // hence removed output memory overlap check.
          .set_check_mem_overlap(false)
          .add_input(idx_nneg_index)
          .add_input(intrsc_counts_nneg_index)
          .add_input(selected_dim_indices_offsets)
          .add_input(intrsc_first_match_nneg_index)
          .build();

      AT_DISPATCH_INDEX_TYPES(nneg_index.scalar_type(), "index_select_sparse_cuda", [&]() {
        index_t* ptr_res_dim_indices = res_dim_indices.mutable_data_ptr<index_t>();
        index_t* ptr_selected_dim_indices = selected_dim_indices.mutable_data_ptr<index_t>();
        const index_t* ptr_argsort_dim_indices = argsort_dim_indices.const_data_ptr<index_t>();
        gpu_kernel(
            iter,
            [ptr_res_dim_indices, ptr_selected_dim_indices, ptr_argsort_dim_indices] GPU_LAMBDA (
              index_t idx_idx, index_t count, index_t offset, index_t first_match
            ) -> index_t {
              index_t* __restrict__ ptr_res_dim_indices_out = ptr_res_dim_indices + offset;
              const index_t* __restrict__ ptr_argsort_dim_indices_in = ptr_argsort_dim_indices + first_match;
              index_t* __restrict__ ptr_selected_dim_indices_out = ptr_selected_dim_indices + offset;
              for (index_t i = 0; i < count; ++i) {
                *ptr_res_dim_indices_out++ = idx_idx;
                *ptr_selected_dim_indices_out++ = *ptr_argsort_dim_indices_in++;
              }

              // A dummy return scalar for a dummy output
              return static_cast<index_t>(1);
            }
        );
      });

      return std::make_tuple(selected_dim_indices, res_dim_indices);
    }();

    return make_output(selected_dim_indices, res_dim_indices);
  }
  // If indexing into dense dimensions
  else {
    // It is sufficient to just perform `index_select` on values
    // if `dim` refers to dense dimensions.
    const auto res_values = values.index_select(dim - sparse_dim + 1, index);

    return _sparse_coo_tensor_with_dims_and_tensors(
        sparse_dim, dense_dim, res_sizes, indices, res_values, self.options());
  }
}
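
// Illustrative usage (not part of the build): selecting along a sparse
// dimension of a COO tensor keeps only the nonzeros whose coordinate matches
// one of the requested indices, e.g.
//   auto dense  = at::eye(4, at::kCUDA);          // nonzeros on the diagonal
//   auto sparse = dense.to_sparse();              // sparse_dim == 2
//   auto idx    = at::tensor({1, 3}, at::dtype(at::kLong).device(at::kCUDA));
//   auto sel    = at::index_select(sparse, /*dim=*/0, idx);  // shape {2, 4}, nnz == 2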


} // at::native