#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/SegmentReduce.h>

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/NumericUtils.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/cub.cuh>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/cat.h>
#include <ATen/ops/cumsum.h>
#endif

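// CUDA backend for the segment_reduce forward and backward passes. These
// kernels are reached through the dispatch stubs registered at the bottom of
// this file (declared in ATen/native/SegmentReduce.h); the generic entry point
// makes the inputs contiguous before calling in here.
//
// A segment is described either by `lengths` or by `offsets` along `axis`.
// For example (illustrative shapes only): with 1-D data of 6 elements and
// lengths = [2, 3, 1], the derived offsets are [0, 2, 5, 6] and
// output[i] = reduce(data[offsets[i] : offsets[i + 1]]).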
namespace at::native {

namespace {
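// Binary functors passed to cub::DeviceSegmentedReduce in the 1-D fast path
// below. CustomMax and CustomMin propagate NaN (a NaN operand wins), matching
// the NaN handling of the custom forward kernel.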
struct CustomMax {
  template <typename OutputT>
  __host__ __device__ __forceinline__ OutputT
  operator()(const OutputT& a, const OutputT& b) const {
    if (at::_isnan(a)) {
      return a;
    } else if (at::_isnan(b)) {
      return b;
    }
    return std::max<OutputT>(a, b);
  }
};

struct CustomSum {
  template <typename OutputT>
  __host__ __device__ __forceinline__ OutputT
  operator()(const OutputT& a, const OutputT& b) const {
    return a + b;
  }
};

struct CustomProd {
  template <typename OutputT>
  __host__ __device__ __forceinline__ OutputT
  operator()(const OutputT& a, const OutputT& b) const {
    return a * b;
  }
};

struct CustomMin {
  template <typename OutputT>
  __host__ __device__ __forceinline__ OutputT
  operator()(const OutputT& a, const OutputT& b) const {
    if (at::_isnan(a)) {
      return a;
    } else if (at::_isnan(b)) {
      return b;
    }
    return std::min<OutputT>(a, b);
  }
};

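// Finishes a MEAN reduction after a segmented SUM: divides each segment's sum
// by its length. Empty segments become `initial` when one was provided and NaN
// otherwise; sums that are already NaN are left as-is.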
template <typename scalar_t, typename index_t>
__global__ static void post_sum_div_kernel(
    scalar_t* output_data,
    const index_t* lengths_data,
    const int64_t segment_count,
    bool is_initial_set,
    scalar_t initial) {
  CUDA_KERNEL_LOOP(index, segment_count) {
    CUDA_KERNEL_ASSERT(lengths_data[index] >= 0);
    if (lengths_data[index] == 0) {
      if (is_initial_set) {
        output_data[index] = initial;
      } else {
        output_data[index] = NAN;
      }
    } else if (!at::_isnan(output_data[index])) {
      output_data[index] = output_data[index] / lengths_data[index];
    }
  }
}

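// Forward kernel for N-D inputs: one thread per output element, addressed as
// (outer_idx, dim_idx, lane_id) = (batch dims before axis, segment index,
// dims after axis). Each thread scans its segment [offset_start, offset_end)
// sequentially and folds the values into `initial_value` per `reduction`.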
template <typename scalar_t, typename index_t>
__global__ void segment_reduce_forward_kernel(
    ReductionType reduction,
    scalar_t* output_data,
    const scalar_t* values_data,
    const index_t* lengths_data,
    const index_t* lengths_cumsum_data,
    const int64_t segment_count,
    const int64_t lengths_stride_axis,
    bool is_initial_set,
    scalar_t initial_value,
    const int64_t outer_offset,
    const int64_t inner_offset,
    const int64_t data_stride_axis,
    const int64_t data_size_axis,
    const int64_t output_stride_axis,
    const int64_t output_size_axis,
    const int64_t lengths_cumsum_stride_axis) {
  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= (outer_offset * segment_count * inner_offset)) {
    return;
  }
  int64_t row_id = idx / inner_offset;
  int64_t lane_id = idx % inner_offset; // lane_id is the inner_idx
  int64_t outer_idx = row_id / segment_count;
  int64_t dim_idx = row_id % segment_count;

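  // ===== step1: locate this thread's segment boundaries from the cumsum of lengths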
  int64_t offset_idx = outer_idx * lengths_cumsum_stride_axis * (segment_count + 1) + dim_idx;
  index_t offset_start = lengths_cumsum_data[offset_idx];
  index_t offset_end = lengths_cumsum_data[offset_idx + 1];

  // ===== step2: apply reduction
  for (index_t j = offset_start; j < offset_end; ++j) {
    int64_t data_index = outer_idx * data_stride_axis * data_size_axis
        + j * data_stride_axis + lane_id;
    const auto data = values_data[data_index];
    // TODO: There is no need to branch with every element
    if (reduction == ReductionType::MAX) {
      initial_value =
          at::_isnan(data) ? data : std::max<scalar_t>(initial_value, data);
    } else if (
        reduction == ReductionType::MEAN ||
        reduction == ReductionType::SUM) {
      initial_value = initial_value + data;
    } else if (reduction == ReductionType::MIN) {
      initial_value =
          at::_isnan(data) ? data : std::min<scalar_t>(initial_value, data);
    } else if (
        reduction == ReductionType::PROD) {
      initial_value = initial_value * data;
    }
  }

  // ===== step3: finalize reduction
  int64_t lengths_idx = outer_idx * lengths_stride_axis * segment_count + dim_idx;
  CUDA_KERNEL_ASSERT(lengths_data[lengths_idx] >= 0);
  if (lengths_data[lengths_idx] == 0 && !is_initial_set &&
      reduction == ReductionType::MEAN) {
    initial_value = static_cast<scalar_t>(NAN);
  } else if (
      reduction == ReductionType::MEAN && lengths_data[lengths_idx] > 0 &&
      !at::_isnan(initial_value)) {
    initial_value = initial_value / lengths_data[lengths_idx];
  }
  int64_t output_index = outer_idx * output_stride_axis * output_size_axis
      + dim_idx * output_stride_axis + lane_id;
  output_data[output_index] = initial_value;
}

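// Backward kernel: same one-thread-per-output-element mapping as the forward
// kernel. Segments of length zero are skipped, leaving their grad_input slots
// at the zeros the buffer was initialized with on the host side.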
template <typename scalar_t, typename index_t>
__global__ void segment_reduce_backward_kernel(
    ReductionType reduction,
    scalar_t* grad_input_data,
    const scalar_t* grad_data,
    const scalar_t* output_data,
    const scalar_t* values_data,
    const index_t* lengths_data,
    const index_t* lengths_cumsum_data,
    const int64_t segment_count,
    const int64_t lengths_stride_axis,
    scalar_t initial_prod_value,
    const int64_t outer_offset,
    const int64_t inner_offset,
    const int64_t data_stride_axis,
    const int64_t data_size_axis,
    const int64_t output_stride_axis,
    const int64_t output_size_axis,
    const int64_t lengths_cumsum_stride_axis) {
  int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= (outer_offset * segment_count * inner_offset)) {
    return;
  }
  int64_t row_id = idx / inner_offset;
  int64_t lane_id = idx % inner_offset; // lane_id is the inner_idx
  int64_t outer_idx = row_id / segment_count;
  int64_t dim_idx = row_id % segment_count;

  int64_t lengths_idx = outer_idx * lengths_stride_axis * segment_count + dim_idx;
  auto segment_length = lengths_data[lengths_idx];
  if (segment_length == 0) {
    return;
  }

  int64_t offset_idx = outer_idx * lengths_cumsum_stride_axis * (segment_count + 1) + dim_idx;
  index_t offset_start = lengths_cumsum_data[offset_idx];
  index_t offset_end = lengths_cumsum_data[offset_idx + 1];

  int64_t output_index = outer_idx * output_stride_axis * output_size_axis
      + dim_idx * output_stride_axis + lane_id;

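  // MAX/MIN: send the upstream gradient to every element that matches the
  // segment's output (NaNs always match, since they propagate through the
  // reduction); ties are averaged in the second pass below.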
  if (reduction == ReductionType::MAX ||
      reduction == ReductionType::MIN) {
    int64_t counter = 0;
    for (int64_t j = offset_start; j < offset_end; ++j) {
      int64_t data_index = outer_idx * data_stride_axis * data_size_axis
          + j * data_stride_axis + lane_id;
      if (at::_isnan(values_data[data_index]) ||
          values_data[data_index] == output_data[output_index]) {
        grad_input_data[data_index] = grad_data[output_index];
        counter++;
      }
    }
    // Average the gradient across the tied max/min elements in the segment
    if (counter < 2) {
      return;
    }
    for (int64_t j = offset_start; j < offset_end; ++j) {
      int64_t data_index = outer_idx * data_stride_axis * data_size_axis
          + j * data_stride_axis + lane_id;
      if (grad_input_data[data_index] > 0) {
        grad_input_data[data_index] =
            grad_input_data[data_index] / counter;
      }
    }
  } else if (reduction == ReductionType::MEAN) {
    auto grad_val = grad_data[output_index] / segment_length;
    for (int64_t j = offset_start; j < offset_end; ++j) {
      int64_t data_index = outer_idx * data_stride_axis * data_size_axis
          + j * data_stride_axis + lane_id;
      grad_input_data[data_index] = grad_val;
    }
  } else if (reduction == ReductionType::SUM) {
    const auto& grad_val = grad_data[output_index];
    for (int64_t j = offset_start; j < offset_end; ++j) {
      int64_t data_index = outer_idx * data_stride_axis * data_size_axis
          + j * data_stride_axis + lane_id;
      grad_input_data[data_index] = grad_val;
    }
  } else if (reduction == ReductionType::PROD) {
    const auto& grad_val = grad_data[output_index] * output_data[output_index];
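    // d(prod)/d(x_j) = prod / x_j; when x_j is 0 or NaN that division is not
    // valid, so fall back to an explicit exclusive product over the segment.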
    for (int64_t j = offset_start; j < offset_end; ++j) {
      int64_t data_index = outer_idx * data_stride_axis * data_size_axis
          + j * data_stride_axis + lane_id;
      if (at::_isnan(values_data[data_index]) ||
          values_data[data_index] == 0) {
        // explicitly compute exclusive prod
        scalar_t exclusive_prod = initial_prod_value;
        int64_t prod_idx;
        for (int64_t k = offset_start; k < offset_end; ++k) {
          if (k != j) {
            prod_idx = outer_idx * data_stride_axis * data_size_axis
                + k * data_stride_axis + lane_id;
            exclusive_prod *= values_data[prod_idx];
          }
        }
        grad_input_data[data_index] = grad_data[output_index] * exclusive_prod;
      } else {
        grad_input_data[data_index] = grad_val / values_data[data_index];
      }
    }
  }
}
} // namespace

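// Host-side backward path shared by the lengths and offsets variants;
// `is_offsets_like` tells us how to interpret `lengths_or_offsets_contig`.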
Tensor _segment_reduce_lengths_offsets_backward_cuda_kernel(
    const Tensor& grad_contig,
    const Tensor& output_contig,
    const Tensor& data_contig,
    ReductionType reduction,
    const Tensor& lengths_or_offsets_contig,
    int64_t axis,
    const std::optional<Scalar>& initial,
    bool is_offsets_like) {
  axis = lengths_or_offsets_contig.dim() - 1;
  int64_t segment_count = is_offsets_like ?
      lengths_or_offsets_contig.size(axis) - 1 :
      lengths_or_offsets_contig.size(axis);
  int64_t lengths_stride_axis = lengths_or_offsets_contig.stride(axis);
  auto grad_input = at::zeros({data_contig.sizes()}, grad_contig.options());

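  // Materialize both views of the segment boundaries: recover lengths from
  // offsets with diff(), or build offsets from lengths by prepending a zero
  // along `axis` and taking a cumulative sum.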
  auto offsets = lengths_or_offsets_contig;
  auto lengths = lengths_or_offsets_contig;
  if (is_offsets_like) {
    lengths = lengths.diff();
  } else {
    auto zeros_shape = offsets.sizes().vec();
    zeros_shape[axis] = 1;
    offsets = at::cat({at::zeros(zeros_shape, offsets.options()), offsets}, axis);
    offsets.cumsum_(axis);
  }

  // outer_offset is the size of the outer dimensions of output (before axis)
  // inner_offset is the size of the inner dimensions of output (after axis)
  int64_t outer_offset = 1, inner_offset = 1;
  for (int64_t d = 0; d < axis; d++) {
    outer_offset *= output_contig.size(d);
  }
  for (int64_t d = axis + 1; d < output_contig.dim(); d++) {
    inner_offset *= output_contig.size(d);
  }

  constexpr int threads_per_block = 256;
  int64_t num_blocks = (outer_offset * inner_offset * segment_count + threads_per_block - 1) / threads_per_block;

  num_blocks = std::max(num_blocks, (int64_t)1);

  auto data_stride_axis = data_contig.stride(axis);
  auto data_size_axis = data_contig.size(axis);
  auto output_stride_axis = output_contig.stride(axis);
  auto output_size_axis = output_contig.size(axis);
  auto offsets_stride_axis = offsets.stride(axis);

  AT_DISPATCH_INDEX_TYPES(
      lengths_or_offsets_contig.scalar_type(), "_segment_reduce_cuda_lengths_offsets_backward_kernel1", ([&] {
        const auto* lengths_data = lengths.const_data_ptr<index_t>();
        auto* offsets_data = offsets.const_data_ptr<index_t>();

        // TODO: Switch to TensorIterator for better maintainability and
        // readability
        AT_DISPATCH_FLOATING_TYPES_AND2(
            kBFloat16,
            kHalf,
            data_contig.scalar_type(),
            "_segment_reduce_cuda",
            ([&]() {
              auto* output_data = output_contig.const_data_ptr<scalar_t>();
              auto* grad_data = grad_contig.const_data_ptr<scalar_t>();
              auto* grad_input_data = grad_input.mutable_data_ptr<scalar_t>();
              const auto* values_data = data_contig.const_data_ptr<scalar_t>();

              scalar_t initial_prod_value;
              if (initial.has_value()) {
                initial_prod_value = initial.value().to<scalar_t>();
              } else {
                initial_prod_value = 1;
              }

              segment_reduce_backward_kernel<scalar_t>
                  <<<num_blocks,
                     threads_per_block,
                     0,
                     at::cuda::getCurrentCUDAStream()>>>(
                      reduction,
                      grad_input_data,
                      grad_data,
                      output_data,
                      values_data,
                      lengths_data,
                      offsets_data,
                      segment_count,
                      lengths_stride_axis,
                      initial_prod_value,
                      outer_offset,
                      inner_offset,
                      data_stride_axis,
                      data_size_axis,
                      output_stride_axis,
                      output_size_axis,
                      offsets_stride_axis
                  );
              C10_CUDA_KERNEL_LAUNCH_CHECK();
            }));
      }));
  return grad_input;
}

Tensor _segment_reduce_lengths_backward_cuda_kernel(
    const Tensor& grad_contig,
    const Tensor& output_contig,
    const Tensor& data_contig,
    ReductionType reduction,
    const Tensor& lengths_contig,
    int64_t axis,
    const std::optional<Scalar>& initial) {
  return _segment_reduce_lengths_offsets_backward_cuda_kernel(
      grad_contig, output_contig, data_contig, reduction, lengths_contig, axis, initial, /*is_offsets_like=*/false);
}

Tensor _segment_reduce_offsets_backward_cuda_kernel(
    const Tensor& grad_contig,
    const Tensor& output_contig,
    const Tensor& data_contig,
    ReductionType reduction,
    const Tensor& offsets_contig,
    int64_t axis,
    const std::optional<Scalar>& initial) {
  return _segment_reduce_lengths_offsets_backward_cuda_kernel(
      grad_contig, output_contig, data_contig, reduction, offsets_contig, axis, initial, /*is_offsets_like=*/true);
}

Tensor _segment_reduce_lengths_offsets_cuda_kernel(
    ReductionType reduction,
    const Tensor& data,
    const Tensor& lengths_or_offsets,
    int64_t axis,
    const std::optional<Scalar>& initial,
    bool is_offsets_like) {
  // data and lengths_or_offsets should be contiguous from the call to .contiguous in segment_reduce_kernel
  TORCH_CHECK(data.is_contiguous());
  TORCH_CHECK(lengths_or_offsets.is_contiguous());
  axis = lengths_or_offsets.dim() - 1;
  int64_t segment_count = is_offsets_like ? lengths_or_offsets.size(axis) - 1 : lengths_or_offsets.size(axis);
  int64_t lengths_stride_axis = lengths_or_offsets.stride(axis);
  auto output_shape = data.sizes().vec();
  output_shape[axis] = segment_count;
  auto output = at::empty(output_shape, data.options());

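  // Normalize to both lengths and offsets, exactly as in the backward path above.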
  auto offsets = lengths_or_offsets;
  auto lengths = lengths_or_offsets;
  if (is_offsets_like) {
    lengths = lengths.diff();
  } else {
    auto zeros_shape = offsets.sizes().vec();
    zeros_shape[axis] = 1;
    offsets = at::cat({at::zeros(zeros_shape, offsets.options()), offsets}, axis);
    offsets.cumsum_(axis);
  }

  // outer_offset is the size of the outer dimensions of output (before axis)
  // inner_offset is the size of the inner dimensions of output (after axis)
  int64_t outer_offset = 1, inner_offset = 1;
  for (int64_t d = 0; d < axis; d++) {
    outer_offset *= output.size(d);
  }
  for (int64_t d = axis + 1; d < output.dim(); d++) {
    inner_offset *= output.size(d);
  }

  constexpr int threads_per_block = 256;
  // one thread per output element: outer_offset * segment_count * inner_offset == output.numel()
  int64_t num_blocks = (output.numel() + threads_per_block - 1) / threads_per_block;

  num_blocks = std::max(num_blocks, (int64_t)1);

  auto data_stride_axis = data.stride(axis);
  auto data_size_axis = data.size(axis);
  auto output_stride_axis = output.stride(axis);
  auto output_size_axis = output.size(axis);
  auto offsets_stride_axis = offsets.stride(axis);

  AT_DISPATCH_INDEX_TYPES(
      lengths_or_offsets.scalar_type(), "_segment_reduce_cuda_kernel1", ([&] {
        auto* offsets_data_ptr = offsets.const_data_ptr<index_t>();
        auto* lengths_data_ptr = lengths.const_data_ptr<index_t>();
        AT_DISPATCH_FLOATING_TYPES_AND2(
            at::ScalarType::Half,
            at::ScalarType::BFloat16,
            data.scalar_type(),
            "segment_reduce_cuda",
            [&]() {
              auto* data_data_ptr = data.const_data_ptr<scalar_t>();
              auto* output_data_ptr = output.mutable_data_ptr<scalar_t>();

              // initialize starting value: use `initial` when provided, otherwise
              // the identity element of the reduction (-inf/inf for max/min,
              // 0 for sum/mean, 1 for prod)
              scalar_t initial_value = 0;
              if (initial.has_value()) {
                initial_value = initial.value().to<scalar_t>();
              } else if (reduction == ReductionType::MAX) {
                initial_value = -std::numeric_limits<scalar_t>::infinity();
              } else if (
                  reduction == ReductionType::MEAN ||
                  reduction == ReductionType::SUM) {
                initial_value = 0;
              } else if (reduction == ReductionType::MIN) {
                initial_value = std::numeric_limits<scalar_t>::infinity();
              } else if (reduction == ReductionType::PROD) {
                initial_value = 1;
              }

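              // N-D case: launch the custom forward kernel. The 1-D case below
              // uses cub::DeviceSegmentedReduce with the offsets tensor providing
              // each segment's begin/end offsets.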
              if (output_shape.size() > 1) {
                segment_reduce_forward_kernel<scalar_t>
                    <<<num_blocks,
                       threads_per_block,
                       0,
                       at::cuda::getCurrentCUDAStream()>>>(
                        reduction,
                        output_data_ptr,
                        data_data_ptr,
                        lengths_data_ptr,
                        offsets_data_ptr,
                        segment_count,
                        lengths_stride_axis,
                        initial.has_value(),
                        initial_value,
                        outer_offset,
                        inner_offset,
                        data_stride_axis,
                        data_size_axis,
                        output_stride_axis,
                        output_size_axis,
                        offsets_stride_axis
                    );
                C10_CUDA_KERNEL_LAUNCH_CHECK();
              } else {
                if (reduction == ReductionType::MAX) {
                  CustomMax max_op{};
                  CUB_WRAPPER(
                      cub::DeviceSegmentedReduce::Reduce,
                      data_data_ptr,
                      output_data_ptr,
                      segment_count,
                      offsets_data_ptr,
                      offsets_data_ptr + 1,
                      max_op,
                      initial_value,
                      at::cuda::getCurrentCUDAStream());
                } else if (reduction == ReductionType::MEAN) {
                  CustomSum sum_op{};
                  CUB_WRAPPER(
                      cub::DeviceSegmentedReduce::Reduce,
                      data_data_ptr,
                      output_data_ptr,
                      segment_count,
                      offsets_data_ptr,
                      offsets_data_ptr + 1,
                      sum_op,
                      initial_value,
                      at::cuda::getCurrentCUDAStream());

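                  // finish the MEAN: divide each segment's sum by its length;
                  // empty segments become `initial` (if set) or NaN in the kernel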
                  post_sum_div_kernel<scalar_t>
                      <<<num_blocks,
                         threads_per_block,
                         0,
                         at::cuda::getCurrentCUDAStream()>>>(
                          output_data_ptr,
                          lengths_data_ptr,
                          segment_count,
                          initial.has_value(),
                          initial_value);
                  C10_CUDA_KERNEL_LAUNCH_CHECK();
                } else if (reduction == ReductionType::MIN) {
                  CustomMin min_op{};
                  CUB_WRAPPER(
                      cub::DeviceSegmentedReduce::Reduce,
                      data_data_ptr,
                      output_data_ptr,
                      segment_count,
                      offsets_data_ptr,
                      offsets_data_ptr + 1,
                      min_op,
                      initial_value,
                      at::cuda::getCurrentCUDAStream());
                } else if (reduction == ReductionType::SUM) {
                  CustomSum sum_op{};
                  CUB_WRAPPER(
                      cub::DeviceSegmentedReduce::Reduce,
                      data_data_ptr,
                      output_data_ptr,
                      segment_count,
                      offsets_data_ptr,
                      offsets_data_ptr + 1,
                      sum_op,
                      initial_value,
                      at::cuda::getCurrentCUDAStream());
                } else if (reduction == ReductionType::PROD) {
                  CustomProd prod_op{};
                  CUB_WRAPPER(
                      cub::DeviceSegmentedReduce::Reduce,
                      data_data_ptr,
                      output_data_ptr,
                      segment_count,
                      offsets_data_ptr,
                      offsets_data_ptr + 1,
                      prod_op,
                      initial_value,
                      at::cuda::getCurrentCUDAStream());
                }
              }
            });
      }));

  return output;
}

Tensor _segment_reduce_lengths_cuda_kernel(
    ReductionType reduction,
    const Tensor& data,
    const Tensor& lengths,
    int64_t axis,
    const std::optional<Scalar>& initial) {
  return _segment_reduce_lengths_offsets_cuda_kernel(
      reduction, data, lengths, axis, initial, /*is_offsets_like=*/false);
}

Tensor _segment_reduce_offsets_cuda_kernel(
    ReductionType reduction,
    const Tensor& data,
    const Tensor& offsets,
    int64_t axis,
    const std::optional<Scalar>& initial) {
  return _segment_reduce_lengths_offsets_cuda_kernel(
      reduction, data, offsets, axis, initial, /*is_offsets_like=*/true);
}

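// Register the CUDA implementations with the segment_reduce dispatch stubs
// declared in ATen/native/SegmentReduce.h (included above).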
REGISTER_DISPATCH(_segment_reduce_lengths_stub, &_segment_reduce_lengths_cuda_kernel);
REGISTER_DISPATCH(_segment_reduce_offsets_stub, &_segment_reduce_offsets_cuda_kernel);
REGISTER_DISPATCH(
    _segment_reduce_lengths_backward_stub,
    &_segment_reduce_lengths_backward_cuda_kernel);
REGISTER_DISPATCH(
    _segment_reduce_offsets_backward_stub,
    &_segment_reduce_offsets_backward_cuda_kernel);

} // namespace at::native