xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/SegmentReduce.cu (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/native/SegmentReduce.h>
3 
4 #include <ATen/core/Tensor.h>
5 #include <ATen/Dispatch.h>
6 #include <ATen/NumericUtils.h>
7 #include <ATen/cuda/CUDAContext.h>
8 #include <ATen/cuda/detail/KernelUtils.h>
9 #include <ATen/cuda/cub.cuh>
10 
11 #ifndef AT_PER_OPERATOR_HEADERS
12 #include <ATen/Functions.h>
13 #else
14 #include <ATen/ops/empty.h>
15 #include <ATen/ops/zeros.h>
16 #include <ATen/ops/cat.h>
17 #include <ATen/ops/cumsum.h>
18 #endif
19 
20 namespace at::native {
21 
22 namespace {
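// Binary reduction functors passed to cub::DeviceSegmentedReduce::Reduce in the
// 1-D fast path below. CustomMax/CustomMin propagate NaN inputs so that a NaN in
// a segment poisons that segment's result, matching the elementwise kernel.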
23 struct CustomMax {
24   template <typename OutputT>
25   __host__ __device__ __forceinline__ OutputT
26   operator()(const OutputT& a, const OutputT& b) const {
27     if (at::_isnan(a)) {
28       return a;
29     } else if (at::_isnan(b)) {
30       return b;
31     }
32     return std::max<OutputT>(a, b);
33   }
34 };
35 
36 struct CustomSum {
37   template <typename OutputT>
38   __host__ __device__ __forceinline__ OutputT
39   operator()(const OutputT& a, const OutputT& b) const {
40     return a + b;
41   }
42 };
43 
44 struct CustomProd {
45   template <typename OutputT>
46   __host__ __device__ __forceinline__ OutputT
47   operator()(const OutputT& a, const OutputT& b) const {
48     return a * b;
49   }
50 };
51 
52 struct CustomMin {
53   template <typename OutputT>
54   __host__ __device__ __forceinline__ OutputT
55   operator()(const OutputT& a, const OutputT& b) const {
56     if (at::_isnan(a)) {
57       return a;
58     } else if (at::_isnan(b)) {
59       return b;
60     }
61     return std::min<OutputT>(a, b);
62   }
63 };
64 
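// Divides each segment's sum by its length to turn a SUM result into a MEAN.
// Empty segments get `initial` when it was supplied, NaN otherwise; NaN sums are
// left untouched so NaN propagation from the reduction is preserved.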
65 template <typename scalar_t, typename index_t>
66 __global__ static void post_sum_div_kernel(
67     scalar_t* output_data,
68     const index_t* lengths_data,
69     const int64_t segment_count,
70     bool is_initial_set,
71     scalar_t initial) {
72   CUDA_KERNEL_LOOP(index, segment_count) {
73     CUDA_KERNEL_ASSERT(lengths_data[index] >= 0);
74     if (lengths_data[index] == 0) {
75       if (is_initial_set) {
76         output_data[index] = initial;
77       } else {
78         output_data[index] = NAN;
79       }
80     } else if (!at::_isnan(output_data[index])) {
81       output_data[index] = output_data[index] / lengths_data[index];
82     }
83   }
84 }
85 
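// Forward kernel: one thread per output element, indexed as
// (outer_idx, dim_idx, lane_id) = (dims before axis, segment, dims after axis).
// Each thread walks its segment's slice [offset_start, offset_end) along the
// reduced axis and accumulates into `initial_value`, which doubles as the result.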
86 template <typename scalar_t, typename index_t>
87 __global__ void segment_reduce_forward_kernel(
88     ReductionType reduction,
89     scalar_t* output_data,
90     const scalar_t* values_data,
91     const index_t* lengths_data,
92     const index_t* lengths_cumsum_data,
93     const int64_t segment_count,
94     const int64_t lengths_stride_axis,
95     bool is_initial_set,
96     scalar_t initial_value,
97     const int64_t outer_offset,
98     const int64_t inner_offset,
99     const int64_t data_stride_axis,
100     const int64_t data_size_axis,
101     const int64_t output_stride_axis,
102     const int64_t output_size_axis,
103     const int64_t lengths_cumsum_stride_axis) {
104   int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
105   if (idx >= (outer_offset * segment_count * inner_offset)) {
106     return;
107   }
108   int64_t row_id = idx / inner_offset;
109   int64_t lane_id = idx % inner_offset;   // lane_id is the inner_idx
110   int64_t outer_idx = row_id / segment_count;
111   int64_t dim_idx = row_id % segment_count;
112 
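  // ===== step1: look up this segment's [offset_start, offset_end) range from the
  // offsets tensor (a zero-prepended cumsum of the lengths, hence segment_count + 1 entries)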
113   int64_t offset_idx = outer_idx * lengths_cumsum_stride_axis * (segment_count + 1) + dim_idx;
114   index_t offset_start = lengths_cumsum_data[offset_idx];
115   index_t offset_end = lengths_cumsum_data[offset_idx + 1];
116 
117   // ===== step2: apply reduction
118   for (index_t j = offset_start; j < offset_end; ++j) {
119     int64_t data_index = outer_idx * data_stride_axis * data_size_axis
120                          + j * data_stride_axis + lane_id;
121     const auto data = values_data[data_index];
122     // TODO: There is no need to branch with every element
123     if (reduction == ReductionType::MAX) {
124       initial_value =
125           at::_isnan(data) ? data : std::max<scalar_t>(initial_value, data);
126     } else if (
127         reduction == ReductionType::MEAN ||
128         reduction == ReductionType::SUM) {
129       initial_value = initial_value + data;
130     } else if (reduction == ReductionType::MIN) {
131       initial_value =
132           at::_isnan(data) ? data : std::min<scalar_t>(initial_value, data);
133     } else if (
134       reduction == ReductionType::PROD) {
135       initial_value = initial_value * data;
136     }
137   }
138 
139   // ===== step3: finalize reduction
140   int64_t lengths_idx = outer_idx * lengths_stride_axis * segment_count + dim_idx;
141   CUDA_KERNEL_ASSERT(lengths_data[lengths_idx] >= 0);
142   if (lengths_data[lengths_idx] == 0 && !is_initial_set &&
143       reduction == ReductionType::MEAN) {
144     initial_value = static_cast<scalar_t>(NAN);
145   } else if (
146       reduction == ReductionType::MEAN && lengths_data[lengths_idx] > 0 &&
147       !at::_isnan(initial_value)) {
148     initial_value = initial_value / lengths_data[lengths_idx];
149   }
150   int64_t output_index = outer_idx * output_stride_axis * output_size_axis
151                          + dim_idx * output_stride_axis + lane_id;
152   output_data[output_index] = initial_value;
153 }
154 
155 
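// Backward kernel: same (outer_idx, dim_idx, lane_id) thread mapping as the
// forward pass. Gradient rules per reduction:
//   MAX/MIN: route the output gradient to every element that ties the result
//            (or is NaN), then split it evenly among the ties;
//   MEAN:    grad / segment_length to every element;
//   SUM:     copy grad to every element;
//   PROD:    grad * output / value, falling back to an explicit exclusive product
//            when the element is zero or NaN.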
156 template <typename scalar_t, typename index_t>
157 __global__ void segment_reduce_backward_kernel(
158     ReductionType reduction,
159     scalar_t* grad_input_data,
160     const scalar_t* grad_data,
161     const scalar_t* output_data,
162     const scalar_t* values_data,
163     const index_t* lengths_data,
164     const index_t* lengths_cumsum_data,
165     const int64_t segment_count,
166     const int64_t lengths_stride_axis,
167     scalar_t initial_prod_value,
168     const int64_t outer_offset,
169     const int64_t inner_offset,
170     const int64_t data_stride_axis,
171     const int64_t data_size_axis,
172     const int64_t output_stride_axis,
173     const int64_t output_size_axis,
174     const int64_t lengths_cumsum_stride_axis) {
175   int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
176   if (idx >= (outer_offset * segment_count * inner_offset)) {
177     return;
178   }
179   int64_t row_id = idx / inner_offset;
180   int64_t lane_id = idx % inner_offset;  // lane_id is the inner_idx
181   int64_t outer_idx = row_id / segment_count;
182   int64_t dim_idx = row_id % segment_count;
183 
184   int64_t lengths_idx = outer_idx * lengths_stride_axis * segment_count + dim_idx;
185   auto segment_length = lengths_data[lengths_idx];
186   if (segment_length == 0) {
187     return;
188   }
189 
190   int64_t offset_idx = outer_idx * lengths_cumsum_stride_axis * (segment_count + 1) + dim_idx;
191   index_t offset_start = lengths_cumsum_data[offset_idx];
192   index_t offset_end = lengths_cumsum_data[offset_idx + 1];
193 
194   int64_t output_index = outer_idx * output_stride_axis * output_size_axis
195                          + dim_idx * output_stride_axis + lane_id;
196 
197   if (reduction == ReductionType::MAX ||
198       reduction == ReductionType::MIN) {
199     int64_t counter = 0;
200     for (int64_t j = offset_start; j < offset_end; ++j) {
201       int64_t data_index = outer_idx * data_stride_axis * data_size_axis
202                            + j * data_stride_axis + lane_id;
203       if (at::_isnan(values_data[data_index]) ||
204           values_data[data_index] == output_data[output_index]) {
205         grad_input_data[data_index] = grad_data[output_index];
206         counter++;
207       }
208     }
209     // Average the gradient over the number of tied max/min elements in the
210     // segment
211     if (counter < 2) {
212       return;
213     }
214     for (int64_t j = offset_start; j < offset_end; ++j) {
215       int64_t data_index = outer_idx * data_stride_axis * data_size_axis
216                            + j * data_stride_axis + lane_id;
217       if (grad_input_data[data_index] > 0) {
218         grad_input_data[data_index] =
219             grad_input_data[data_index] / counter;
220       }
221     }
222   } else if (reduction == ReductionType::MEAN) {
223     auto grad_val = grad_data[output_index] / segment_length;
224     for (int64_t j = offset_start; j < offset_end; ++j) {
225       int64_t data_index = outer_idx * data_stride_axis * data_size_axis
226                            + j * data_stride_axis + lane_id;
227       grad_input_data[data_index] = grad_val;
228     }
229   } else if (reduction == ReductionType::SUM) {
230     const auto& grad_val = grad_data[output_index];
231     for (int64_t j = offset_start; j < offset_end; ++j) {
232       int64_t data_index = outer_idx * data_stride_axis * data_size_axis
233                            + j * data_stride_axis + lane_id;
234       grad_input_data[data_index] = grad_val;
235     }
236   } else if (reduction == ReductionType::PROD) {
237     const auto& grad_val = grad_data[output_index] * output_data[output_index];
238     for (int64_t j = offset_start; j < offset_end; ++j) {
239       int64_t data_index = outer_idx * data_stride_axis * data_size_axis
240                            + j * data_stride_axis + lane_id;
241       if (at::_isnan(values_data[data_index]) ||
242           values_data[data_index] == 0) {
243         // explicitly compute exclusive prod
244         scalar_t exclusive_prod = initial_prod_value;
245         int64_t prod_idx;
246         for (int64_t k = offset_start; k < offset_end; ++k) {
247           if (k != j) {
248             prod_idx = outer_idx * data_stride_axis * data_size_axis
249                        + k * data_stride_axis + lane_id;
250             exclusive_prod *= values_data[prod_idx];
251           }
252         }
253         grad_input_data[data_index] = grad_data[output_index] * exclusive_prod;
254       } else {
255         grad_input_data[data_index] = grad_val / values_data[data_index];
256       }
257     }
258   }
259 }
260 } // namespace
261 
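// Host-side backward entry point shared by the lengths and offsets variants.
// It derives whichever of {lengths, offsets} was not supplied (diff for offsets,
// zero-prepended cumsum for lengths), computes the launch geometry, and dispatches
// on the index and floating-point dtypes before launching the backward kernel.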
262 Tensor _segment_reduce_lengths_offsets_backward_cuda_kernel(
263     const Tensor& grad_contig,
264     const Tensor& output_contig,
265     const Tensor& data_contig,
266     ReductionType reduction,
267     const Tensor& lengths_or_offsets_contig,
268     int64_t axis,
269     const std::optional<Scalar>& initial,
270     bool is_offsets_like) {
271   axis = lengths_or_offsets_contig.dim() - 1;
272   int64_t segment_count = is_offsets_like ?
273                           lengths_or_offsets_contig.size(axis) - 1 :
274                           lengths_or_offsets_contig.size(axis);
275   int64_t lengths_stride_axis = lengths_or_offsets_contig.stride(axis);
276   auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options());
277 
278   auto offsets = lengths_or_offsets_contig;
279   auto lengths = lengths_or_offsets_contig;
280   if (is_offsets_like) {
281     lengths = lengths.diff();
282   } else {
283     auto zeros_shape = offsets.sizes().vec();
284     zeros_shape[axis] = 1;
285     offsets = at::cat({at::zeros(zeros_shape, offsets.options()), offsets}, axis);
286     offsets.cumsum_(axis);
287   }
288 
289   // outer_offset is the size of the outer dimensions of output (before axis)
290   // inner_offset is the size of the inner dimensions of output (after axis)
291   int64_t outer_offset = 1, inner_offset = 1;
292   for (int64_t d = 0; d < axis; d++) {
293     outer_offset *= output_contig.size(d);
294   }
295   for (int64_t d = axis + 1; d < output_contig.dim(); d++) {
296     inner_offset *= output_contig.size(d);
297   }
298 
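  // Launch one thread per (outer, segment, inner) triple; each thread writes the
  // gradient for its entire segment.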
299   constexpr int threads_per_block = 256;
300   int64_t num_blocks = (outer_offset * inner_offset * segment_count + threads_per_block - 1) / threads_per_block;
301 
302   num_blocks = std::max(num_blocks, (int64_t)1);
303 
304   auto data_stride_axis = data_contig.stride(axis);
305   auto data_size_axis = data_contig.size(axis);
306   auto output_stride_axis = output_contig.stride(axis);
307   auto output_size_axis = output_contig.size(axis);
308   auto offsets_stride_axis = offsets.stride(axis);
309 
310   AT_DISPATCH_INDEX_TYPES(
311       lengths_or_offsets_contig.scalar_type(), "_segment_reduce_cuda_lengths_offsets_backward_kernel1", ([&] {
312         const auto* lengths_data = lengths.const_data_ptr<index_t>();
313         auto* offsets_data = offsets.const_data_ptr<index_t>();
314 
315         // TODO: Switch to TensorIterator for better maintainability and
316         // readability
317         AT_DISPATCH_FLOATING_TYPES_AND2(
318             kBFloat16,
319             kHalf,
320             data_contig.scalar_type(),
321             "_segment_reduce_cpu",
322             ([&]() {
323               auto* output_data = output_contig.const_data_ptr<scalar_t>();
324               auto* grad_data = grad_contig.const_data_ptr<scalar_t>();
325               auto* grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
326               const auto* values_data = data_contig.const_data_ptr<scalar_t>();
327 
328               scalar_t initial_prod_value;
329               if (initial.has_value()) {
330                 initial_prod_value = initial.value().to<scalar_t>();
331               } else {
332                 initial_prod_value = 1;
333               }
334 
335               segment_reduce_backward_kernel<scalar_t>
336                   <<<num_blocks,
337                      threads_per_block,
338                      0,
339                      at::cuda::getCurrentCUDAStream()>>>(
340                       reduction,
341                       grad_input_data,
342                       grad_data,
343                       output_data,
344                       values_data,
345                       lengths_data,
346                       offsets_data,
347                       segment_count,
348                       lengths_stride_axis,
349                       initial_prod_value,
350                       outer_offset,
351                       inner_offset,
352                       data_stride_axis,
353                       data_size_axis,
354                       output_stride_axis,
355                       output_size_axis,
356                       offsets_stride_axis
357                     );
358               C10_CUDA_KERNEL_LAUNCH_CHECK();
359             }));
360       }));
361   return grad_input;
362 }
363 
364 Tensor _segment_reduce_lengths_backward_cuda_kernel(
365   const Tensor& grad_contig,
366   const Tensor& output_contig,
367   const Tensor& data_contig,
368   ReductionType reduction,
369   const Tensor& lengths_contig,
370   int64_t axis,
371   const std::optional<Scalar>& initial) {
372   return _segment_reduce_lengths_offsets_backward_cuda_kernel(
373     grad_contig, output_contig, data_contig, reduction, lengths_contig, axis, initial, /*is_offsets_like=*/false);
374 }
375 
376 Tensor _segment_reduce_offsets_backward_cuda_kernel(
377   const Tensor& grad_contig,
378   const Tensor& output_contig,
379   const Tensor& data_contig,
380   ReductionType reduction,
381   const Tensor& offsets_contig,
382   int64_t axis,
383   const std::optional<Scalar>& initial) {
384   return _segment_reduce_lengths_offsets_backward_cuda_kernel(
385     grad_contig, output_contig, data_contig, reduction, offsets_contig, axis, initial, /*is_offsets_like=*/true);
386 }
387 
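// Host-side forward entry point shared by the lengths and offsets variants. For
// multi-dimensional inputs it launches segment_reduce_forward_kernel with one
// thread per output element; for the 1-D case it calls cub::DeviceSegmentedReduce
// directly (with a post_sum_div_kernel pass to finish MEAN).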
388 Tensor _segment_reduce_lengths_offsets_cuda_kernel(
389   ReductionType reduction,
390   const Tensor& data,
391   const Tensor& lengths_or_offsets,
392   int64_t axis,
393   const std::optional<Scalar>& initial,
394   bool is_offsets_like) {
395   // data and lengths_or_offsets should be contiguous from the call to .contiguous in segment_reduce_kernel
396   TORCH_CHECK(data.is_contiguous());
397   TORCH_CHECK(lengths_or_offsets.is_contiguous());
398   axis = lengths_or_offsets.dim() - 1;
399   int64_t segment_count = is_offsets_like ? lengths_or_offsets.size(axis) - 1 : lengths_or_offsets.size(axis);
400   int64_t lengths_stride_axis = lengths_or_offsets.stride(axis);
401   auto output_shape = data.sizes().vec();
402   output_shape[axis] = segment_count;
403   auto output = at::empty(output_shape, data.options());
404 
405 
406   auto offsets = lengths_or_offsets;
407   auto lengths = lengths_or_offsets;
408   if (is_offsets_like) {
409     lengths = lengths.diff();
410   } else {
411     auto zeros_shape = offsets.sizes().vec();
412     zeros_shape[axis] = 1;
413     offsets = at::cat({at::zeros(zeros_shape, offsets.options()), offsets}, axis);
414     offsets.cumsum_(axis);
415   }
416 
417   // outer_offset is the size of the outer dimensions of output (before axis)
418   // inner_offset is the size of the inner dimensions of output (after axis)
419   int64_t outer_offset = 1, inner_offset = 1;
420   for (int64_t d = 0; d < axis; d++) {
421     outer_offset *= output.size(d);
422   }
423   for (int64_t d = axis + 1; d < output.dim(); d++) {
424     inner_offset *= output.size(d);
425   }
426 
427   constexpr int threads_per_block = 256;
428   // segment_count * stride_count is just output.numel() ?
429   int64_t num_blocks = (output.numel() + threads_per_block - 1) / threads_per_block;
430 
431   num_blocks = std::max(num_blocks, (int64_t)1);
432 
433   auto data_stride_axis = data.stride(axis);
434   auto data_size_axis = data.size(axis);
435   auto output_stride_axis = output.stride(axis);
436   auto output_size_axis = output.size(axis);
437   auto offsets_stride_axis = offsets.stride(axis);
438 
439   AT_DISPATCH_INDEX_TYPES(
440       lengths_or_offsets.scalar_type(), "_segment_reduce_cuda_kernel1", ([&] {
441         auto* offsets_data_ptr = offsets.const_data_ptr<index_t>();
442         auto* lengths_data_ptr = lengths.const_data_ptr<index_t>();
443         AT_DISPATCH_FLOATING_TYPES_AND2(
444             at::ScalarType::Half,
445             at::ScalarType::BFloat16,
446             data.scalar_type(),
447             "segment_reduce_cuda",
448             [&]() {
449               auto* data_data_ptr = data.const_data_ptr<scalar_t>();
450               auto* output_data_ptr = output.mutable_data_ptr<scalar_t>();
451 
452               // initialize starting value
453               scalar_t initial_value = 0;
454               if (initial.has_value()) {
455                 initial_value = initial.value().to<scalar_t>();
456               } else if (reduction == ReductionType::MAX) {
457                 initial_value = -std::numeric_limits<scalar_t>::infinity();
458               } else if (
459                   reduction == ReductionType::MEAN ||
460                   reduction == ReductionType::SUM) {
461                 initial_value = 0;
462               } else if (reduction == ReductionType::MIN) {
463                 initial_value = std::numeric_limits<scalar_t>::infinity();
464               } else if (reduction == ReductionType::PROD) {
465                 initial_value = 1;
466               }
467 
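              // Multi-dimensional case: the custom kernel handles arbitrary
              // outer/inner dimensions around the reduced axis.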
468               if (output_shape.size() > 1) {
469                 segment_reduce_forward_kernel<scalar_t>
470                     <<<num_blocks,
471                        threads_per_block,
472                        0,
473                        at::cuda::getCurrentCUDAStream()>>>(
474                         reduction,
475                         output_data_ptr,
476                         data_data_ptr,
477                         lengths_data_ptr,
478                         offsets_data_ptr,
479                         segment_count,
480                         lengths_stride_axis,
481                         initial.has_value(),
482                         initial_value,
483                         outer_offset,
484                         inner_offset,
485                         data_stride_axis,
486                         data_size_axis,
487                         output_stride_axis,
488                         output_size_axis,
489                         offsets_stride_axis
490                       );
491                 C10_CUDA_KERNEL_LAUNCH_CHECK();
492               } else {
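                // 1-D case: each segment is a contiguous run of `data`, so reduce
                // it with cub::DeviceSegmentedReduce::Reduce, using offsets_data_ptr
                // and offsets_data_ptr + 1 as the per-segment begin/end offsets.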
493                 if (reduction == ReductionType::MAX) {
494                   CustomMax max_op{};
495                   CUB_WRAPPER(
496                       cub::DeviceSegmentedReduce::Reduce,
497                       data_data_ptr,
498                       output_data_ptr,
499                       segment_count,
500                       offsets_data_ptr,
501                       offsets_data_ptr + 1,
502                       max_op,
503                       initial_value,
504                       at::cuda::getCurrentCUDAStream());
505                 } else if (reduction == ReductionType::MEAN) {
506                   CustomSum sum_op{};
507                   CUB_WRAPPER(
508                       cub::DeviceSegmentedReduce::Reduce,
509                       data_data_ptr,
510                       output_data_ptr,
511                       segment_count,
512                       offsets_data_ptr,
513                       offsets_data_ptr + 1,
514                       sum_op,
515                       initial_value,
516                       at::cuda::getCurrentCUDAStream());
517 
518                   post_sum_div_kernel<scalar_t>
519                       <<<num_blocks,
520                          threads_per_block,
521                          0,
522                          at::cuda::getCurrentCUDAStream()>>>(
523                           output_data_ptr,
524                           lengths_data_ptr,
525                           segment_count,
526                           initial.has_value(),
527                           initial_value);
528                   C10_CUDA_KERNEL_LAUNCH_CHECK();
529                 } else if (reduction == ReductionType::MIN) {
530                   CustomMin min_op{};
531                   CUB_WRAPPER(
532                       cub::DeviceSegmentedReduce::Reduce,
533                       data_data_ptr,
534                       output_data_ptr,
535                       segment_count,
536                       offsets_data_ptr,
537                       offsets_data_ptr + 1,
538                       min_op,
539                       initial_value,
540                       at::cuda::getCurrentCUDAStream());
541                 } else if (reduction == ReductionType::SUM) {
542                   CustomSum sum_op{};
543                   CUB_WRAPPER(
544                       cub::DeviceSegmentedReduce::Reduce,
545                       data_data_ptr,
546                       output_data_ptr,
547                       segment_count,
548                       offsets_data_ptr,
549                       offsets_data_ptr + 1,
550                       sum_op,
551                       initial_value,
552                       at::cuda::getCurrentCUDAStream());
553                 } else if (reduction == ReductionType::PROD) {
554                   CustomProd prod_op{};
555                   CUB_WRAPPER(
556                       cub::DeviceSegmentedReduce::Reduce,
557                       data_data_ptr,
558                       output_data_ptr,
559                       segment_count,
560                       offsets_data_ptr,
561                       offsets_data_ptr + 1,
562                       prod_op,
563                       initial_value,
564                       at::cuda::getCurrentCUDAStream());
565                 }
566               }
567             });
568       }));
569 
570   return output;
571 }
572 
573 Tensor _segment_reduce_lengths_cuda_kernel(
574   ReductionType reduction,
575   const Tensor& data,
576   const Tensor& lengths,
577   int64_t axis,
578   const std::optional<Scalar>& initial) {
579   return _segment_reduce_lengths_offsets_cuda_kernel(
580     reduction, data, lengths, axis, initial, /*is_offsets_like=*/false);
581 }
582 
583 Tensor _segment_reduce_offsets_cuda_kernel(
584   ReductionType reduction,
585   const Tensor& data,
586   const Tensor& offsets,
587   int64_t axis,
588   const std::optional<Scalar>& initial) {
589   return _segment_reduce_lengths_offsets_cuda_kernel(
590     reduction, data, offsets, axis, initial, /*is_offsets_like=*/true);
591 }
592 
593 REGISTER_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cuda_kernel);
594 REGISTER_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cuda_kernel);
595 REGISTER_DISPATCH(
596     _segment_reduce_lengths_backward_stub,
597     &_segment_reduce_lengths_backward_cuda_kernel);
598 REGISTER_DISPATCH(
599   _segment_reduce_offsets_backward_stub,
600   &_segment_reduce_offsets_backward_cuda_kernel);
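// A minimal usage sketch (an assumption about the Python surface; this file only
// provides the CUDA dispatch targets registered above, which torch._segment_reduce
// routes to for CUDA tensors):
//   data    = torch.tensor([1., 2., 3., 4.], device="cuda")
//   lengths = torch.tensor([2, 2], device="cuda")
//   out     = torch._segment_reduce(data, "sum", lengths=lengths, axis=0)  # -> [3., 7.]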
601 
602 } // namespace at::native
603