#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include #include #include #include #ifndef AT_PER_OPERATOR_HEADERS #include #include #else #include #include #include #endif template static void compute_cpu( const index_t* repeat_ptr, const int64_t* cumsum_ptr, index_t* result_ptr, int64_t size, int64_t result_size) { TORCH_CHECK( (result_size == cumsum_ptr[size - 1]), "allocated size does not match required size"); at::parallel_for(0, size, 1, [&](int64_t i_begin, int64_t i_end) { for (const auto i : c10::irange(i_begin, i_end)) { int64_t end = cumsum_ptr[i]; index_t size = repeat_ptr[i]; TORCH_CHECK((size >= 0), "repeats can not be negative"); int64_t start = end - size; for (const auto j : c10::irange(start, end)) { result_ptr[j] = i; } } }); } namespace at::native { Tensor repeat_interleave_cpu( const Tensor& repeat, std::optional output_size) { Tensor output; AT_DISPATCH_INDEX_TYPES(repeat.scalar_type(), "repeat_interleave_cpu", [&]() { output = repeat_interleave_common>( repeat, output_size); }); return output; } Tensor repeat_interleave_symint( const Tensor& self, const Tensor& repeats, std::optional dim, std::optional output_size) { Tensor input = self; // Store conj and neg bits const auto conj = input.is_conj(); if (conj) { input = input.conj(); } const auto neg = input.is_neg(); if (neg) { input = input._neg_view(); } if (!dim) { input = input.flatten(); dim = 0; } Tensor repeats_ = repeats; if (repeats.dim() == 0 || (repeats.dim() == 1 && repeats.sym_size(0) == 1)) { repeats_ = repeats.reshape({1}).expand_symint({input.sym_size(dim.value())}); } else if (repeats.dim() == 1) { TORCH_CHECK( repeats.sym_size(0) == input.sym_size(dim.value()), "repeats must have the same size as input along dim, but got repeats.size(0) = ", repeats.sym_size(0), " and input.size(", dim.value(), ") = ", input.sym_size(dim.value()) ); } else { AT_ERROR("repeats must be 0-dim or 1-dim tensor"); } auto ret = input.index_select( dim.value(), at::repeat_interleave_symint(repeats_, std::move(output_size))); // Restore conj and neg bits if (conj) { ret = ret.conj(); } if (neg) { ret = ret._neg_view(); } return ret; } Tensor repeat_interleave_symint( const Tensor& self, c10::SymInt repeats, std::optional dim_opt, std::optional output_size) { Tensor input = dim_opt ? self : self.flatten(); int64_t dim = c10::maybe_wrap_dim(dim_opt.value_or(0), self.dim()); TORCH_CHECK(repeats >= 0, "Repeats must be non-negative"); input = input.unsqueeze(dim + 1); auto expand_shape = input.sym_sizes().vec(); expand_shape[dim + 1] = repeats; input = input.expand_symint(expand_shape); // This argument doesn't really make sense for the scalar overload, but exists // for consistency with the tensor overload if (output_size) { auto calculated_size = (repeats * expand_shape[dim]).guard_int(__FILE__, __LINE__); TORCH_CHECK(*output_size == calculated_size, "repeat_interleave: Invalid output_size, expected ", calculated_size, " but got ", *output_size); } return input.clone(at::MemoryFormat::Contiguous).flatten(dim, dim + 1); } } // namespace at::native