xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/utils/ParamUtils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
#include <c10/util/ArrayRef.h>
#include <c10/util/Exception.h>
#include <vector>
5 
6 namespace at {
7 namespace native {
8 
9 template <typename T>
_expand_param_if_needed(ArrayRef<T> list_param,const char * param_name,int64_t expected_dim)10 inline std::vector<T> _expand_param_if_needed(
11     ArrayRef<T> list_param,
12     const char* param_name,
13     int64_t expected_dim) {
14   if (list_param.size() == 1) {
15     return std::vector<T>(expected_dim, list_param[0]);
16   } else if ((int64_t)list_param.size() != expected_dim) {
17     std::ostringstream ss;
18     ss << "expected " << param_name << " to be a single integer value or a "
19        << "list of " << expected_dim << " values to match the convolution "
20        << "dimensions, but got " << param_name << "=" << list_param;
21     AT_ERROR(ss.str());
22   } else {
23     return list_param.vec();
24   }
25 }
26 
expand_param_if_needed(IntArrayRef list_param,const char * param_name,int64_t expected_dim)27 inline std::vector<int64_t> expand_param_if_needed(
28     IntArrayRef list_param,
29     const char* param_name,
30     int64_t expected_dim) {
31   return _expand_param_if_needed(list_param, param_name, expected_dim);
32 }
33 
expand_param_if_needed(SymIntArrayRef list_param,const char * param_name,int64_t expected_dim)34 inline std::vector<c10::SymInt> expand_param_if_needed(
35     SymIntArrayRef list_param,
36     const char* param_name,
37     int64_t expected_dim) {
38   return _expand_param_if_needed(list_param, param_name, expected_dim);
39 }
40 
41 } // namespace native
42 } // namespace at
43