#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/TensorAdvancedIndexing.h>

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/MemoryOverlap.h>

#include <ATen/native/ScatterGatherChecks.h>
#include <ATen/native/ReduceOpsUtils.h>
#include <ATen/native/TensorIterator.h>

#include <ATen/native/cuda/Loops.cuh>
#include <ATen/native/cuda/KernelUtils.cuh>
#include <ATen/cuda/detail/OffsetCalculator.cuh>
#include <ATen/cuda/Atomic.cuh>
#include <ATen/cuda/CUDAContext.h>

namespace at::native {

// Implement as functors since lambdas don't get optimized.
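// Each functor implements
//   operator()(scalar_t* self_data_start, int64_t index, int64_t numel, const scalar_t* src_data)
// and combines *src_data into self_data_start[index] (atomically for the
// reduce functors, as a plain store for TensorAssign). `numel` is the number
// of elements in the destination tensor; only fastAtomicAdd consumes it, the
// other functors ignore it.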
class ReduceMultiply {
public:
  template <typename scalar_t>
  constexpr C10_DEVICE void operator() (scalar_t* self_data_start, int64_t index, int64_t numel, const scalar_t * src_data) const {
    (void)numel; // suppress unused warning
    gpuAtomicMul(self_data_start + index, *src_data);
  }
};
static ReduceMultiply reduce_multiply;

class ReduceAdd {
public:
  template <typename scalar_t>
  constexpr C10_DEVICE void operator() (scalar_t* self_data_start, int64_t index, int64_t numel, const scalar_t * src_data) const {
    fastAtomicAdd(self_data_start, index, numel, *src_data, true);
  }
};
static ReduceAdd reduce_add;

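// Note: ReduceMean only accumulates a sum on the device (it is identical to
// ReduceAdd); the division by the number of contributing elements is expected
// to be performed by the caller, outside of these kernels.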
class ReduceMean {
public:
  template <typename scalar_t>
  constexpr C10_DEVICE void operator() (scalar_t* self_data_start, int64_t index, int64_t numel, const scalar_t * src_data) const {
    fastAtomicAdd(self_data_start, index, numel, *src_data, true);
  }
};
static ReduceMean reduce_mean;

class ReduceMinimum {
public:
  template <typename scalar_t>
  constexpr C10_DEVICE void operator() (scalar_t* self_data_start, int64_t index, int64_t numel, const scalar_t * src_data) const {
    (void)numel; // suppress unused warning
    gpuAtomicMin(self_data_start + index, *src_data);
  }
};
static ReduceMinimum reduce_minimum;

class ReduceMaximum {
public:
  template <typename scalar_t>
  constexpr C10_DEVICE void operator() (scalar_t* self_data_start, int64_t index, int64_t numel, const scalar_t * src_data) const {
    (void)numel; // suppress unused warning
    gpuAtomicMax(self_data_start + index, *src_data);
  }
};
static ReduceMaximum reduce_maximum;

class TensorAssign {
public:
  template <typename scalar_t>
  constexpr C10_DEVICE void operator() (scalar_t* self_data_start, int64_t index, int64_t numel, const scalar_t * src_data) const {
    (void)numel; // suppress unused warning
    *(self_data_start + index) = *src_data;
  }
};
static TensorAssign tensor_assign;

// The kernels are implemented on an opaque,
// self-aligned type of the correct size,
// to avoid redundant kernels for different types
// of the same size.
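// For example, with cast_to_opaque=true both float and int lower to
// OpaqueType<4>, so one instantiation of the kernels below serves every
// 4-byte dtype.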
template <int N> struct alignas(N) OpaqueType { char data[N]; };

// Essentially a rewrite of the legacy::launch_kernel machinery.
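// Each block handles a contiguous tile of nt * vt indices: thread t starts at
// tile_begin + t and strides by nt, so it touches at most vt elements. The
// launcher below sizes the grid as ceil(N / (nt * vt)).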
template <int nt, int vt, typename func_t>
C10_LAUNCH_BOUNDS_2(nt, vt)
__global__ void _scatter_gather_elementwise_kernel(int N, func_t f) {
  constexpr int nv = nt * vt;
  int idx = nv * blockIdx.x + threadIdx.x;

  #pragma unroll
  for (int i = 0; i < vt; ++i) {
    if (idx < N) {
      f(idx);
      idx += nt;
    }
  }
}

template <int nt, int vt, typename func_t>
static void _launch_scatter_gather_kernel(int64_t N, const func_t& f) {
  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
  if (N == 0) {
    return;
  }

  const dim3 block(nt);
  const dim3 grid((N + block.x * vt - 1) / (block.x * vt));
  const auto stream = at::cuda::getCurrentCUDAStream();
  _scatter_gather_elementwise_kernel<nt, vt, func_t><<<grid, block, 0, stream>>>(N, f);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}


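// The loop below reads one int64 index per element of the (restrided)
// iteration space. In scatter-like mode the index offsets the destination:
// f receives self_ptr + offsets[0] with index = idx_dim * index_stride.
// In gather mode the destination index is 0 and idx_dim * index_stride is
// applied to the source pointer instead.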
template <bool is_scatter_like, typename scalar_t>
struct _cuda_scatter_gather_internal_kernel {
  template <typename func_t>
  void operator() (
    TensorIterator& iter,
    int64_t index_size,
    int64_t index_stride,
    int64_t numel,  // Do not use the `const` qualifier here as it may cause issues with CUDA 11.6.x. See #75434, #75545
    const func_t& f
  ) {
    if (!iter.can_use_32bit_indexing()) {
      for (auto& sub_iter : iter.with_32bit_indexing()) {
        _cuda_scatter_gather_internal_kernel<is_scatter_like, scalar_t>()(
          sub_iter, index_size, index_stride, numel, f
        );
      }
      return;
    }

    char* self_ptr = (char*)iter.data_ptr(0);
    char* src_ptr = (char*)iter.data_ptr(1);
    char* index_ptr = (char*)iter.data_ptr(2);

    auto offset_calc = make_offset_calculator<3>(iter);
    auto loop = [=]C10_DEVICE(int i) {
      auto offsets = offset_calc.get(i);

      int64_t idx_dim = *(int64_t*)(index_ptr + offsets[2]);
      CUDA_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size
        && "index out of bounds");

      f(
        (scalar_t*)(self_ptr + offsets[0]),
        is_scatter_like ? idx_dim * index_stride : 0,
        numel,
        (scalar_t*)(src_ptr + offsets[1]) + (is_scatter_like ? 0 : idx_dim * index_stride)
      );
    };

    _launch_scatter_gather_kernel<num_threads(), thread_work_size()>(iter.numel(), loop);
  }
}; // struct _cuda_scatter_gather_internal_kernel

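// Example of the restriding below for a scatter along dim == 1: self is viewed
// with shape == index.sizes() and self.stride(1) forced to 0, so for each
// coordinate (i, j) of index the iterator yields a pointer to self[i][0]; the
// functor then reaches self[i][index[i][j]] by adding index[i][j] * self_dim_stride.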
template <bool is_scatter_like = true, bool cast_to_opaque = true>
struct cuda_scatter_gather_base_kernel {
  void operator()(
    const Tensor& self, int64_t dim,
    const Tensor& index, const Tensor& src,
    const std::string& method_name,
    const ReduceAdd& f
  ) {
    at::assert_no_internal_overlap(self);

    auto index_sizes = ensure_nonempty_vec(index.sizes().vec());
    auto self_strides = ensure_nonempty_vec(self.strides().vec());
    auto src_strides = ensure_nonempty_vec(src.strides().vec());

    // restride self and src such that
    // self.shape = src.shape = index.shape
    //
    // restride stride[dim] such that
    // if (is_scatter_like) self.stride[dim] = 0
    // else src.stride[dim] = 0
    auto self_restrided = is_scatter_like ?
        restride_dim(self, dim, index_sizes)
      : self.as_strided(index_sizes, self_strides);
    auto src_restrided = is_scatter_like ?
        src.as_strided(index_sizes, src_strides)
      : restride_dim(src, dim, index_sizes);

    auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .add_output(self_restrided)
      .add_const_input(src_restrided)
      .add_const_input(index)
      .build();

    auto self_dim_stride = ensure_nonempty_stride(self, dim);
    auto self_dim_size = ensure_nonempty_size(self, dim);

    auto src_dim_stride = ensure_nonempty_stride(src, dim);
    auto src_dim_size = ensure_nonempty_size(src, dim);

    auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
    auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;


    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
      iter.dtype(),
      "cuda_scatter_gather_base_kernel_func", [&] {
        using dtype = typename std::conditional<cast_to_opaque,
          OpaqueType<sizeof(scalar_t)>, scalar_t>::type;

        _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype>()(
          iter, index_size, index_stride, self.numel(), f
        );
      }
    );
  }

  void operator()(
    const Tensor& self, int64_t dim,
    const Tensor& index, const Tensor& src,
    const std::string& method_name,
    const TensorAssign& f
  ) {
    at::assert_no_internal_overlap(self);

    auto index_sizes = ensure_nonempty_vec(index.sizes().vec());
    auto self_strides = ensure_nonempty_vec(self.strides().vec());
    auto src_strides = ensure_nonempty_vec(src.strides().vec());

    // restride self and src such that
    // self.shape = src.shape = index.shape
    //
    // restride stride[dim] such that
    // if (is_scatter_like) self.stride[dim] = 0
    // else src.stride[dim] = 0
    auto self_restrided = is_scatter_like ?
        restride_dim(self, dim, index_sizes)
      : self.as_strided(index_sizes, self_strides);
    auto src_restrided = is_scatter_like ?
        src.as_strided(index_sizes, src_strides)
      : restride_dim(src, dim, index_sizes);

    auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .add_output(self_restrided)
      .add_const_input(src_restrided)
      .add_const_input(index)
      .build();

    auto self_dim_stride = ensure_nonempty_stride(self, dim);
    auto self_dim_size = ensure_nonempty_size(self, dim);

    auto src_dim_stride = ensure_nonempty_stride(src, dim);
    auto src_dim_size = ensure_nonempty_size(src, dim);

    auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
    auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;


    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
      iter.dtype(),
      "cuda_scatter_gather_base_kernel_func", [&] {
        using dtype = typename std::conditional<cast_to_opaque,
          OpaqueType<sizeof(scalar_t)>, scalar_t>::type;

        _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype>()(
          iter, index_size, index_stride, self.numel(), f
        );
      }
    );
  }

  template <typename func_t>
  void operator()(
    const Tensor& self, int64_t dim,
    const Tensor& index, const Tensor& src,
    const std::string& method_name,
    const func_t& f
  ) {
    at::assert_no_internal_overlap(self);

    auto index_sizes = ensure_nonempty_vec(index.sizes().vec());
    auto self_strides = ensure_nonempty_vec(self.strides().vec());
    auto src_strides = ensure_nonempty_vec(src.strides().vec());

    // restride self and src such that
    // self.shape = src.shape = index.shape
    //
    // restride stride[dim] such that
    // if (is_scatter_like) self.stride[dim] = 0
    // else src.stride[dim] = 0
    auto self_restrided = is_scatter_like ?
        restride_dim(self, dim, index_sizes)
      : self.as_strided(index_sizes, self_strides);
    auto src_restrided = is_scatter_like ?
        src.as_strided(index_sizes, src_strides)
      : restride_dim(src, dim, index_sizes);

    auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .add_output(self_restrided)
      .add_const_input(src_restrided)
      .add_const_input(index)
      .build();

    auto self_dim_stride = ensure_nonempty_stride(self, dim);
    auto self_dim_size = ensure_nonempty_size(self, dim);

    auto src_dim_stride = ensure_nonempty_stride(src, dim);
    auto src_dim_size = ensure_nonempty_size(src, dim);

    auto index_size = is_scatter_like ? self_dim_size : src_dim_size;
    auto index_stride = is_scatter_like ? self_dim_stride : src_dim_stride;


    AT_DISPATCH_ALL_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      iter.dtype(),
      "cuda_scatter_gather_base_kernel_func", [&] {
        using dtype = typename std::conditional<cast_to_opaque,
          OpaqueType<sizeof(scalar_t)>, scalar_t>::type;

        _cuda_scatter_gather_internal_kernel<is_scatter_like, dtype>()(
          iter, index_size, index_stride, self.numel(), f
        );
      }
    );
  }
}; // struct cuda_scatter_gather_base_kernel

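// Same structure as _cuda_scatter_gather_internal_kernel, but there is no src
// tensor: the iterator only carries self and index, and the scalar src_val is
// captured by value in the device lambda so every selected element of self is
// combined with the same value.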
template <typename scalar_t>
struct _cuda_scatter_fill_internal_kernel {
  template <typename func_t>
  void operator()(
    TensorIterator& iter,
    scalar_t src_val,
    int64_t index_size,
    int64_t index_stride,
    int64_t numel,  // Do not use the `const` qualifier here as it may cause issues with CUDA 11.6.x. See #75434, #75545
    const func_t& f
  ) {
    if (!iter.can_use_32bit_indexing()) {
      for (auto& sub_iter : iter.with_32bit_indexing()) {
        _cuda_scatter_fill_internal_kernel<scalar_t>()(
          sub_iter, src_val, index_size, index_stride, numel, f
        );
      }
      return;
    }

    char* self_ptr = (char*)iter.data_ptr(0);
    char* index_ptr = (char*)iter.data_ptr(1);

    auto offset_calc = make_offset_calculator<2>(iter);
    auto loop = [=]C10_DEVICE(int i) {
      auto offsets = offset_calc.get(i);

      int64_t idx_dim = *(int64_t*)(index_ptr + offsets[1]);
      CUDA_KERNEL_ASSERT(idx_dim >= 0 && idx_dim < index_size
        && "index out of bounds"
      );

      f(
        (scalar_t*)(self_ptr + offsets[0]),
        idx_dim * index_stride,
        numel,
        (scalar_t*)&src_val
      );
    };

    _launch_scatter_gather_kernel<num_threads(), thread_work_size()>(iter.numel(), loop);
  }
}; // struct _cuda_scatter_fill_internal_kernel

template <bool cast_to_opaque = true>
struct cuda_scatter_fill_base_kernel {
  template <typename func_t>
  void operator()(
    const Tensor& self, int64_t dim,
    const Tensor& index, Scalar src,
    const std::string& method_name,
    const func_t& f
  ) {
    at::assert_no_internal_overlap(self);

    auto index_sizes = ensure_nonempty_vec(index.sizes().vec());

    // restride self such that
    // self.shape = index.shape and
    // self.stride[dim] = 0
    auto self_restrided = restride_dim(self, dim, index_sizes);

    auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .add_output(self_restrided)
      .add_const_input(index)
      .build();

    auto index_size = ensure_nonempty_size(self, dim);
    auto index_stride = ensure_nonempty_stride(self, dim);

    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
      at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
      iter.dtype(),
      "cuda_scatter_fill_base_kernel_func", [&] {
        using dtype = typename std::conditional<cast_to_opaque,
          OpaqueType<sizeof(scalar_t)>, scalar_t>::type;

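        // Reinterpret the scalar's bits as the (possibly opaque) kernel dtype so
        // that one kernel instantiation can be shared by same-sized scalar types.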
        auto src_scalar_val = src.to<scalar_t>();
        auto src_val = *(dtype*)&src_scalar_val;

        _cuda_scatter_fill_internal_kernel<dtype>()(
          iter, src_val, index_size, index_stride, self.numel(), f
        );
      }
    );
  }

  void operator()(
    const Tensor& self, int64_t dim,
    const Tensor& index, Scalar src,
    const std::string& method_name,
    const ReduceMultiply& f
  ) {
    at::assert_no_internal_overlap(self);

    auto index_sizes = ensure_nonempty_vec(index.sizes().vec());

    // restride self such that
    // self.shape = index.shape and
    // self.stride[dim] = 0
    auto self_restrided = restride_dim(self, dim, index_sizes);

    auto iter = TensorIteratorConfig()
      .set_check_mem_overlap(false)
      .check_all_same_dtype(false)
      .resize_outputs(false)
      .add_output(self_restrided)
      .add_const_input(index)
      .build();

    auto index_size = ensure_nonempty_size(self, dim);
    auto index_stride = ensure_nonempty_stride(self, dim);

    AT_DISPATCH_ALL_TYPES_AND2(
      at::ScalarType::Half, at::ScalarType::BFloat16,
      iter.dtype(),
      "cuda_scatter_fill_base_kernel_reduce_multiply", [&] {
        using dtype = typename std::conditional<cast_to_opaque,
          OpaqueType<sizeof(scalar_t)>, scalar_t>::type;

        auto src_scalar_val = src.to<scalar_t>();
        auto src_val = *(dtype*)&src_scalar_val;

        _cuda_scatter_fill_internal_kernel<dtype>()(
          iter, src_val, index_size, index_stride, self.numel(), f
        );
      }
    );
  }
}; // struct cuda_scatter_fill_base_kernel

void gather_cuda_kernel(const Tensor& result, const Tensor& self, int64_t dim, const Tensor& index) {
  cuda_scatter_gather_base_kernel</*is_scatter_like=*/false>()(
    result, dim, index, self,
    "gather_out_cuda", tensor_assign);
}

void scatter_cuda_kernel(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) {
  // When indices are not unique, the behavior is non-deterministic
  globalContext().alertNotDeterministic("scatter_cuda_");
  cuda_scatter_gather_base_kernel<>()(
    self, dim, index, src,
    "scatter_cuda_", tensor_assign);
}

void scatter_fill_cuda_kernel(const Tensor& self, int64_t dim, const Tensor& index, const Scalar& src) {
  cuda_scatter_fill_base_kernel<>()(
    self, dim, index, src,
    "scatter_fill_cuda_", tensor_assign);
}

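// scatter_add and the reduce kernels below pass cast_to_opaque=false: their
// functors do real arithmetic, so they must be instantiated with the actual
// scalar type rather than an OpaqueType of the same size.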
void scatter_add_cuda_kernel(const Tensor& self, int64_t dim, const Tensor& index, const Tensor& src) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd usage
  globalContext().alertNotDeterministic("scatter_add_cuda_kernel");
  cuda_scatter_gather_base_kernel</*is_scatter_like=*/true, /*cast_to_opaque=*/false>()(
    self, dim, index, src,
    "scatter_add_cuda_", reduce_add);
}

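// Only ReductionType::SUM and ReductionType::PROD are handled here; any other
// reduction falls through the default case without launching a kernel.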
void scatter_reduce_cuda_kernel(const Tensor& self, const int64_t dim, const Tensor& index,
                                const Tensor& src, const ReductionType& reduce) {
  // See Note [Writing Nondeterministic Operations]
  // Nondeterministic because of atomicAdd/AtomicMul usage
  globalContext().alertNotDeterministic("scatter_reduce_cuda_kernel");
  switch (reduce) {
  case ReductionType::SUM :
    cuda_scatter_gather_base_kernel<true, false>()(self, dim, index, src,
                                       "scatter_reduce_cuda_add_", reduce_add);
    break;
  case ReductionType::PROD :
    cuda_scatter_gather_base_kernel<true, false>()(self, dim, index, src,
                                       "scatter_reduce_cuda_multiply_", reduce_multiply);
    break;
  default :
    break;
  }
}

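// Note: only the sum/prod/mean cases raise the nondeterminism alert here,
// presumably because repeated atomic max/min updates yield the same result
// regardless of the order in which they are applied.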
void scatter_reduce_two_cuda_kernel(const Tensor& self, const int64_t dim, const Tensor& index,
                                    const Tensor& src, const ReductionType& reduce) {
  switch (reduce) {
  case ReductionType::SUM :
    globalContext().alertNotDeterministic("scatter_reduce_cuda_sum_");
    cuda_scatter_gather_base_kernel<true, false>()(self, dim, index, src,
            "scatter_reduce_cuda_sum_", reduce_add);
    break;
  case ReductionType::PROD :
    globalContext().alertNotDeterministic("scatter_reduce_cuda_prod_");
    cuda_scatter_gather_base_kernel<true, false>()(self, dim, index, src,
            "scatter_reduce_cuda_prod_", reduce_multiply);
    break;
  case ReductionType::MAX :
    cuda_scatter_gather_base_kernel<true, false>()(self, dim, index, src,
            "scatter_reduce_cuda_amax_", reduce_maximum);
    break;
  case ReductionType::MIN :
    cuda_scatter_gather_base_kernel<true, false>()(self, dim, index, src,
            "scatter_reduce_cuda_amin_", reduce_minimum);
    break;
  case ReductionType::MEAN :
    globalContext().alertNotDeterministic("scatter_reduce_cuda_mean_");
    cuda_scatter_gather_base_kernel<true, false>()(self, dim, index, src,
            "scatter_reduce_cuda_mean_", reduce_mean);
    break;
  }
}

void scatter_scalar_reduce_cuda_kernel(const Tensor& self, const int64_t dim, const Tensor& index,
                                       const Scalar& value, const ReductionType& reduce) {
  switch (reduce) {
  case ReductionType::SUM :
    cuda_scatter_fill_base_kernel<false>()(self, dim, index, value,
                                      "scatter_fill_cuda_add_", reduce_add);
    break;
  case ReductionType::PROD :
    cuda_scatter_fill_base_kernel<false>()(self, dim, index, value,
                                      "scatter_fill_cuda_multiply_", reduce_multiply);
    break;
  default :
    break;
  }
}


REGISTER_DISPATCH(gather_stub, &gather_cuda_kernel);
REGISTER_DISPATCH(scatter_stub, &scatter_cuda_kernel);
REGISTER_DISPATCH(scatter_fill_stub, &scatter_fill_cuda_kernel);
REGISTER_DISPATCH(scatter_add_stub, &scatter_add_cuda_kernel);
REGISTER_DISPATCH(scatter_reduce_stub, &scatter_reduce_cuda_kernel);
REGISTER_DISPATCH(scatter_scalar_reduce_stub, &scatter_scalar_reduce_cuda_kernel);
REGISTER_DISPATCH(scatter_reduce_two_stub, &scatter_reduce_two_cuda_kernel);

} // namespace at::native