xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/UpSample.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 // Copyright 2004-present Facebook. All Rights Reserved.
2 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
3 
4 #include <ATen/native/UpSample.h>
5 #include <c10/util/irange.h>
6 #include <c10/util/TypeCast.h>
7 
8 namespace at::native::upsample {
9 
compute_output_size(c10::IntArrayRef input_size,at::OptionalIntArrayRef output_size,std::optional<c10::ArrayRef<double>> scale_factors)10 TORCH_API c10::SmallVector<int64_t, 3> compute_output_size(
11     c10::IntArrayRef input_size,  // Full input tensor size.
12     at::OptionalIntArrayRef output_size,
13     std::optional<c10::ArrayRef<double>> scale_factors) {
14   const auto spatial_dimensions = static_cast<int64_t>(input_size.size()) - 2;
15   if (output_size) {
16     TORCH_CHECK(!scale_factors, "Must specify exactly one of output_size and scale_factors");
17     TORCH_CHECK(static_cast<int64_t>(output_size->size()) == spatial_dimensions);
18     return {output_size->data(), output_size->data() + output_size->size()};
19   }
20   if (scale_factors) {
21     TORCH_CHECK(!output_size, "Must specify exactly one of output_size and scale_factors");
22     TORCH_CHECK(static_cast<int64_t>(scale_factors->size()) == spatial_dimensions);
23     c10::SmallVector<int64_t, 3> ret;
24     for (const auto i : c10::irange(spatial_dimensions)) {
25       const double odim = static_cast<double>(input_size[i+2]) * scale_factors.value()[i];
26       ret.push_back(c10::checked_convert<int64_t>(odim, "int64_t"));
27     }
28     return ret;
29   }
30   TORCH_CHECK(false, "Must specify exactly one of output_size and scale_factors");
31 }
32 
33 } // namespace at::native::upsample
34