1 #pragma once
2
3 #include <ATen/core/Tensor.h>
4 #include <ATen/TensorOperators.h>
5
6 #ifndef AT_PER_OPERATOR_HEADERS
7 #include <ATen/Functions.h>
8 #else
9 #include <ATen/ops/empty.h>
10 #include <ATen/ops/empty_like.h>
11 #endif
12
13 namespace at::native {
14
15 template <
16 typename index_t,
17 void compute(const index_t*, const int64_t*, index_t*, int64_t, int64_t)>
repeat_interleave_common(const Tensor & repeats,std::optional<int64_t> output_size)18 static inline Tensor repeat_interleave_common(
19 const Tensor& repeats,
20 std::optional<int64_t> output_size) {
21 TORCH_CHECK(
22 repeats.dim() == 1, "repeat_interleave only accept 1D vector as repeat");
23 TORCH_CHECK(
24 repeats.scalar_type() == at::kLong || repeats.scalar_type() == at::kInt,
25 "repeats has to be Long or Int tensor");
26 if (repeats.size(0) == 0) {
27 return at::empty_like(repeats, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
28 }
29 Tensor repeats_ = repeats.contiguous();
30 Tensor cumsum = repeats.cumsum(0);
31 int64_t total = 0;
32 if (output_size.has_value()) {
33 total = output_size.value();
34 } else {
35 total = cumsum[-1].item<int64_t>();
36 TORCH_CHECK(
37 (repeats >= 0).all().item<uint8_t>(), "repeats can not be negative");
38 }
39
40 Tensor result = at::empty({total}, repeats.options());
41 const index_t* repeat_ptr = repeats_.const_data_ptr<index_t>();
42 const int64_t* cumsum_ptr = cumsum.const_data_ptr<int64_t>();
43 index_t* result_ptr = result.data_ptr<index_t>();
44 compute(repeat_ptr, cumsum_ptr, result_ptr, repeats.size(0), total);
45 return result;
46 }
47
48 } // namespace at::native
49