// Copyright 2004-present Facebook. All Rights Reserved. #define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include #include namespace at::native::upsample { TORCH_API c10::SmallVector compute_output_size( c10::IntArrayRef input_size, // Full input tensor size. at::OptionalIntArrayRef output_size, std::optional> scale_factors) { const auto spatial_dimensions = static_cast(input_size.size()) - 2; if (output_size) { TORCH_CHECK(!scale_factors, "Must specify exactly one of output_size and scale_factors"); TORCH_CHECK(static_cast(output_size->size()) == spatial_dimensions); return {output_size->data(), output_size->data() + output_size->size()}; } if (scale_factors) { TORCH_CHECK(!output_size, "Must specify exactly one of output_size and scale_factors"); TORCH_CHECK(static_cast(scale_factors->size()) == spatial_dimensions); c10::SmallVector ret; for (const auto i : c10::irange(spatial_dimensions)) { const double odim = static_cast(input_size[i+2]) * scale_factors.value()[i]; ret.push_back(c10::checked_convert(odim, "int64_t")); } return ret; } TORCH_CHECK(false, "Must specify exactly one of output_size and scale_factors"); } } // namespace at::native::upsample