1 // Copyright 2004-present Facebook. All Rights Reserved.
2 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
3
4 #include <ATen/native/UpSample.h>
5 #include <c10/util/irange.h>
6 #include <c10/util/TypeCast.h>
7
8 namespace at::native::upsample {
9
compute_output_size(c10::IntArrayRef input_size,at::OptionalIntArrayRef output_size,std::optional<c10::ArrayRef<double>> scale_factors)10 TORCH_API c10::SmallVector<int64_t, 3> compute_output_size(
11 c10::IntArrayRef input_size, // Full input tensor size.
12 at::OptionalIntArrayRef output_size,
13 std::optional<c10::ArrayRef<double>> scale_factors) {
14 const auto spatial_dimensions = static_cast<int64_t>(input_size.size()) - 2;
15 if (output_size) {
16 TORCH_CHECK(!scale_factors, "Must specify exactly one of output_size and scale_factors");
17 TORCH_CHECK(static_cast<int64_t>(output_size->size()) == spatial_dimensions);
18 return {output_size->data(), output_size->data() + output_size->size()};
19 }
20 if (scale_factors) {
21 TORCH_CHECK(!output_size, "Must specify exactly one of output_size and scale_factors");
22 TORCH_CHECK(static_cast<int64_t>(scale_factors->size()) == spatial_dimensions);
23 c10::SmallVector<int64_t, 3> ret;
24 for (const auto i : c10::irange(spatial_dimensions)) {
25 const double odim = static_cast<double>(input_size[i+2]) * scale_factors.value()[i];
26 ret.push_back(c10::checked_convert<int64_t>(odim, "int64_t"));
27 }
28 return ret;
29 }
30 TORCH_CHECK(false, "Must specify exactly one of output_size and scale_factors");
31 }
32
33 } // namespace at::native::upsample
34