1 #pragma once
2
#include <c10/core/SymIntArrayRef.h>
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
#include <cstdint>
#include <sstream>
#include <vector>
5
6 namespace at {
7 namespace native {
8
9 template <typename T>
_expand_param_if_needed(ArrayRef<T> list_param,const char * param_name,int64_t expected_dim)10 inline std::vector<T> _expand_param_if_needed(
11 ArrayRef<T> list_param,
12 const char* param_name,
13 int64_t expected_dim) {
14 if (list_param.size() == 1) {
15 return std::vector<T>(expected_dim, list_param[0]);
16 } else if ((int64_t)list_param.size() != expected_dim) {
17 std::ostringstream ss;
18 ss << "expected " << param_name << " to be a single integer value or a "
19 << "list of " << expected_dim << " values to match the convolution "
20 << "dimensions, but got " << param_name << "=" << list_param;
21 AT_ERROR(ss.str());
22 } else {
23 return list_param.vec();
24 }
25 }
26
expand_param_if_needed(IntArrayRef list_param,const char * param_name,int64_t expected_dim)27 inline std::vector<int64_t> expand_param_if_needed(
28 IntArrayRef list_param,
29 const char* param_name,
30 int64_t expected_dim) {
31 return _expand_param_if_needed(list_param, param_name, expected_dim);
32 }
33
expand_param_if_needed(SymIntArrayRef list_param,const char * param_name,int64_t expected_dim)34 inline std::vector<c10::SymInt> expand_param_if_needed(
35 SymIntArrayRef list_param,
36 const char* param_name,
37 int64_t expected_dim) {
38 return _expand_param_if_needed(list_param, param_name, expected_dim);
39 }
40
41 } // namespace native
42 } // namespace at
43