xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/Repeat.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/Tensor.h>
4 #include <ATen/TensorOperators.h>
5 
6 #ifndef AT_PER_OPERATOR_HEADERS
7 #include <ATen/Functions.h>
8 #else
9 #include <ATen/ops/empty.h>
10 #include <ATen/ops/empty_like.h>
11 #endif
12 
13 namespace at::native {
14 
15 template <
16     typename index_t,
17     void compute(const index_t*, const int64_t*, index_t*, int64_t, int64_t)>
repeat_interleave_common(const Tensor & repeats,std::optional<int64_t> output_size)18 static inline Tensor repeat_interleave_common(
19     const Tensor& repeats,
20     std::optional<int64_t> output_size) {
21   TORCH_CHECK(
22       repeats.dim() == 1, "repeat_interleave only accept 1D vector as repeat");
23   TORCH_CHECK(
24       repeats.scalar_type() == at::kLong || repeats.scalar_type() == at::kInt,
25       "repeats has to be Long or Int tensor");
26   if (repeats.size(0) == 0) {
27     return at::empty_like(repeats, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
28   }
29   Tensor repeats_ = repeats.contiguous();
30   Tensor cumsum = repeats.cumsum(0);
31   int64_t total = 0;
32   if (output_size.has_value()) {
33     total = output_size.value();
34   } else {
35     total = cumsum[-1].item<int64_t>();
36     TORCH_CHECK(
37         (repeats >= 0).all().item<uint8_t>(), "repeats can not be negative");
38   }
39 
40   Tensor result = at::empty({total}, repeats.options());
41   const index_t* repeat_ptr = repeats_.const_data_ptr<index_t>();
42   const int64_t* cumsum_ptr = cumsum.const_data_ptr<int64_t>();
43   index_t* result_ptr = result.data_ptr<index_t>();
44   compute(repeat_ptr, cumsum_ptr, result_ptr, repeats.size(0), total);
45   return result;
46 }
47 
48 } // namespace at::native
49