xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/nn/modules/utils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/util/ArrayRef.h>
4 #include <c10/util/irange.h>
5 #include <optional>
6 
7 #include <vector>
8 
9 namespace torch {
10 namespace nn {
11 namespace modules {
12 namespace utils {
13 
14 // Reverse the order of `t` and repeat each element for `n` times.
15 // This can be used to translate padding arg used by Conv and Pooling modules
16 // to the ones used by `F::pad`.
17 //
18 // This mirrors `_reverse_repeat_tuple` in `torch/nn/modules/utils.py`.
_reverse_repeat_vector(at::ArrayRef<int64_t> t,int64_t n)19 inline std::vector<int64_t> _reverse_repeat_vector(
20     at::ArrayRef<int64_t> t,
21     int64_t n) {
22   TORCH_INTERNAL_ASSERT(n >= 0);
23   std::vector<int64_t> ret;
24   ret.reserve(t.size() * n);
25   for (auto rit = t.rbegin(); rit != t.rend(); ++rit) {
26     for (const auto i : c10::irange(n)) {
27       (void)i; // Suppress unused variable
28       ret.emplace_back(*rit);
29     }
30   }
31   return ret;
32 }
33 
_list_with_default(torch::ArrayRef<std::optional<int64_t>> out_size,torch::IntArrayRef defaults)34 inline std::vector<int64_t> _list_with_default(
35     torch::ArrayRef<std::optional<int64_t>> out_size,
36     torch::IntArrayRef defaults) {
37   TORCH_CHECK(
38       defaults.size() > out_size.size(),
39       "Input dimension should be at least ",
40       out_size.size() + 1);
41   std::vector<int64_t> ret;
42   torch::IntArrayRef defaults_slice =
43       defaults.slice(defaults.size() - out_size.size(), out_size.size());
44   for (const auto i : c10::irange(out_size.size())) {
45     auto v = out_size.at(i);
46     auto d = defaults_slice.at(i);
47     ret.emplace_back(v.has_value() ? v.value() : d);
48   }
49   return ret;
50 }
51 
52 } // namespace utils
53 } // namespace modules
54 } // namespace nn
55 } // namespace torch
56