1 #pragma once 2 3 #include <c10/util/ArrayRef.h> 4 #include <c10/util/irange.h> 5 #include <optional> 6 7 #include <vector> 8 9 namespace torch { 10 namespace nn { 11 namespace modules { 12 namespace utils { 13 14 // Reverse the order of `t` and repeat each element for `n` times. 15 // This can be used to translate padding arg used by Conv and Pooling modules 16 // to the ones used by `F::pad`. 17 // 18 // This mirrors `_reverse_repeat_tuple` in `torch/nn/modules/utils.py`. _reverse_repeat_vector(at::ArrayRef<int64_t> t,int64_t n)19inline std::vector<int64_t> _reverse_repeat_vector( 20 at::ArrayRef<int64_t> t, 21 int64_t n) { 22 TORCH_INTERNAL_ASSERT(n >= 0); 23 std::vector<int64_t> ret; 24 ret.reserve(t.size() * n); 25 for (auto rit = t.rbegin(); rit != t.rend(); ++rit) { 26 for (const auto i : c10::irange(n)) { 27 (void)i; // Suppress unused variable 28 ret.emplace_back(*rit); 29 } 30 } 31 return ret; 32 } 33 _list_with_default(torch::ArrayRef<std::optional<int64_t>> out_size,torch::IntArrayRef defaults)34inline std::vector<int64_t> _list_with_default( 35 torch::ArrayRef<std::optional<int64_t>> out_size, 36 torch::IntArrayRef defaults) { 37 TORCH_CHECK( 38 defaults.size() > out_size.size(), 39 "Input dimension should be at least ", 40 out_size.size() + 1); 41 std::vector<int64_t> ret; 42 torch::IntArrayRef defaults_slice = 43 defaults.slice(defaults.size() - out_size.size(), out_size.size()); 44 for (const auto i : c10::irange(out_size.size())) { 45 auto v = out_size.at(i); 46 auto d = defaults_slice.at(i); 47 ret.emplace_back(v.has_value() ? v.value() : d); 48 } 49 return ret; 50 } 51 52 } // namespace utils 53 } // namespace modules 54 } // namespace nn 55 } // namespace torch 56