xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/invalid_arguments.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/utils/invalid_arguments.h>
2 
3 #include <torch/csrc/utils/python_strings.h>
4 
5 #include <c10/util/irange.h>
6 
7 #include <algorithm>
8 #include <memory>
9 #include <unordered_map>
10 
11 namespace torch {
12 
13 namespace {
14 
py_typename(PyObject * object)15 std::string py_typename(PyObject* object) {
16   return Py_TYPE(object)->tp_name;
17 }
18 
19 struct Type {
20   Type() = default;
21   Type(const Type&) = default;
22   Type& operator=(const Type&) = default;
23   Type(Type&&) noexcept = default;
24   Type& operator=(Type&&) noexcept = default;
25   virtual bool is_matching(PyObject* object) = 0;
26   virtual ~Type() = default;
27 };
28 
29 struct SimpleType : public Type {
SimpleTypetorch::__anon9622be550111::SimpleType30   SimpleType(std::string& name) : name(name){};
31 
is_matchingtorch::__anon9622be550111::SimpleType32   bool is_matching(PyObject* object) override {
33     return py_typename(object) == name;
34   }
35 
36   std::string name;
37 };
38 
39 struct MultiType : public Type {
MultiTypetorch::__anon9622be550111::MultiType40   MultiType(std::initializer_list<std::string> accepted_types)
41       : types(accepted_types){};
42 
is_matchingtorch::__anon9622be550111::MultiType43   bool is_matching(PyObject* object) override {
44     auto it = std::find(types.begin(), types.end(), py_typename(object));
45     return it != types.end();
46   }
47 
48   std::vector<std::string> types;
49 };
50 
51 struct NullableType : public Type {
NullableTypetorch::__anon9622be550111::NullableType52   NullableType(std::unique_ptr<Type> type) : type(std::move(type)){};
53 
is_matchingtorch::__anon9622be550111::NullableType54   bool is_matching(PyObject* object) override {
55     return object == Py_None || type->is_matching(object);
56   }
57 
58   std::unique_ptr<Type> type;
59 };
60 
61 struct TupleType : public Type {
TupleTypetorch::__anon9622be550111::TupleType62   TupleType(std::vector<std::unique_ptr<Type>> types)
63       : types(std::move(types)){};
64 
is_matchingtorch::__anon9622be550111::TupleType65   bool is_matching(PyObject* object) override {
66     if (!PyTuple_Check(object))
67       return false;
68     auto num_elements = PyTuple_GET_SIZE(object);
69     if (num_elements != (long)types.size())
70       return false;
71     for (const auto i : c10::irange(num_elements)) {
72       if (!types[i]->is_matching(PyTuple_GET_ITEM(object, i)))
73         return false;
74     }
75     return true;
76   }
77 
78   std::vector<std::unique_ptr<Type>> types;
79 };
80 
81 struct SequenceType : public Type {
SequenceTypetorch::__anon9622be550111::SequenceType82   SequenceType(std::unique_ptr<Type> type) : type(std::move(type)){};
83 
is_matchingtorch::__anon9622be550111::SequenceType84   bool is_matching(PyObject* object) override {
85     if (!PySequence_Check(object))
86       return false;
87     auto num_elements = PySequence_Length(object);
88     for (const auto i : c10::irange(num_elements)) {
89       if (!type->is_matching(
90               py::reinterpret_steal<py::object>(PySequence_GetItem(object, i))
91                   .ptr()))
92         return false;
93     }
94     return true;
95   }
96 
97   std::unique_ptr<Type> type;
98 };
99 
100 struct Argument {
Argumenttorch::__anon9622be550111::Argument101   Argument(std::string name, std::unique_ptr<Type> type)
102       : name(std::move(name)), type(std::move(type)){};
103 
104   std::string name;
105   std::unique_ptr<Type> type;
106 };
107 
108 struct Option {
Optiontorch::__anon9622be550111::Option109   Option(std::vector<Argument> arguments, bool is_variadic, bool has_out)
110       : arguments(std::move(arguments)),
111         is_variadic(is_variadic),
112         has_out(has_out){};
Optiontorch::__anon9622be550111::Option113   Option(bool is_variadic, bool has_out)
114       : arguments(), is_variadic(is_variadic), has_out(has_out){};
115   Option(const Option&) = delete;
116   Option(Option&& other) noexcept = default;
117   Option& operator=(const Option&) = delete;
118   Option& operator=(Option&&) = delete;
119 
120   std::vector<Argument> arguments;
121   bool is_variadic;
122   bool has_out;
123 };
124 
/// Splits `s` on every occurrence of `delim`.
///
/// Adjacent delimiters, or a delimiter at either end, produce empty tokens, so
/// the result always has (number of delimiters + 1) elements. An empty `s`
/// yields a single empty token.
///
/// An empty `delim` cannot separate anything, so the whole input is returned
/// as one token. Without this guard, `find("")` always matches at `start`
/// and the loop below never terminates.
std::vector<std::string> _splitString(
    const std::string& s,
    const std::string& delim) {
  std::vector<std::string> tokens;
  if (delim.empty()) {
    tokens.push_back(s);
    return tokens;
  }
  size_t start = 0;
  size_t end = 0;
  while ((end = s.find(delim, start)) != std::string::npos) {
    tokens.push_back(s.substr(start, end - start));
    start = end + delim.length();
  }
  tokens.push_back(s.substr(start));
  return tokens;
}
138 
_buildType(std::string type_name,bool is_nullable)139 std::unique_ptr<Type> _buildType(std::string type_name, bool is_nullable) {
140   std::unique_ptr<Type> result;
141   if (type_name == "float") {
142     result = std::make_unique<MultiType>(MultiType{"float", "int", "long"});
143   } else if (type_name == "int") {
144     result = std::make_unique<MultiType>(MultiType{"int", "long"});
145   } else if (type_name.find("tuple[") == 0) {
146     auto type_list = type_name.substr(6);
147     type_list.pop_back();
148     std::vector<std::unique_ptr<Type>> types;
149     for (auto& type : _splitString(type_list, ","))
150       types.emplace_back(_buildType(type, false));
151     result = std::make_unique<TupleType>(std::move(types));
152   } else if (type_name.find("sequence[") == 0) {
153     auto subtype = type_name.substr(9);
154     subtype.pop_back();
155     result = std::make_unique<SequenceType>(_buildType(subtype, false));
156   } else {
157     result = std::make_unique<SimpleType>(type_name);
158   }
159   if (is_nullable)
160     result = std::make_unique<NullableType>(std::move(result));
161   return result;
162 }
163 
_parseOption(const std::string & _option_str,const std::unordered_map<std::string,PyObject * > & kwargs)164 std::pair<Option, std::string> _parseOption(
165     const std::string& _option_str,
166     const std::unordered_map<std::string, PyObject*>& kwargs) {
167   if (_option_str == "no arguments")
168     return std::pair<Option, std::string>(Option(false, false), _option_str);
169   bool has_out = false;
170   std::vector<Argument> arguments;
171   std::string printable_option = _option_str;
172   std::string option_str = _option_str.substr(1, _option_str.length() - 2);
173 
174   /// XXX: this is a hack only for the out arg in TensorMethods
175   auto out_pos = printable_option.find('#');
176   if (out_pos != std::string::npos) {
177     if (kwargs.count("out") > 0) {
178       std::string kwonly_part = printable_option.substr(out_pos + 1);
179       printable_option.erase(out_pos);
180       printable_option += "*, ";
181       printable_option += kwonly_part;
182     } else if (out_pos >= 2) {
183       printable_option.erase(out_pos - 2);
184       printable_option += ")";
185     } else {
186       printable_option.erase(out_pos);
187       printable_option += ")";
188     }
189     has_out = true;
190   }
191 
192   for (auto& arg : _splitString(option_str, ", ")) {
193     bool is_nullable = false;
194     auto type_start_idx = 0;
195     if (arg[type_start_idx] == '#') {
196       type_start_idx++;
197     }
198     if (arg[type_start_idx] == '[') {
199       is_nullable = true;
200       type_start_idx++;
201       arg.erase(arg.length() - std::string(" or None]").length());
202     }
203 
204     auto type_end_idx = arg.find_last_of(' ');
205     auto name_start_idx = type_end_idx + 1;
206 
207     // "type ... name" => "type ... name"
208     //          ^              ^
209     auto dots_idx = arg.find("...");
210     if (dots_idx != std::string::npos)
211       type_end_idx -= 4;
212 
213     std::string type_name =
214         arg.substr(type_start_idx, type_end_idx - type_start_idx);
215     std::string name = arg.substr(name_start_idx);
216 
217     arguments.emplace_back(name, _buildType(type_name, is_nullable));
218   }
219 
220   bool is_variadic = option_str.find("...") != std::string::npos;
221   return std::pair<Option, std::string>(
222       Option(std::move(arguments), is_variadic, has_out),
223       std::move(printable_option));
224 }
225 
/// Returns true when the number of supplied arguments (positional + keyword)
/// fits `option`. The count must be exact, or larger for variadic options.
/// A declared out argument that was not passed by keyword is not counted as
/// expected.
bool _argcountMatch(
    const Option& option,
    const std::vector<PyObject*>& arguments,
    const std::unordered_map<std::string, PyObject*>& kwargs) {
  const size_t provided = arguments.size() + kwargs.size();
  const bool out_missing = option.has_out && kwargs.count("out") == 0;
  const size_t required =
      option.arguments.size() - static_cast<size_t>(out_missing);
  if (provided == required)
    return true;
  // Note: variadic functions don't accept kwargs, so it's ok
  return option.is_variadic && provided > required;
}
238 
_formattedArgDesc(const Option & option,const std::vector<PyObject * > & arguments,const std::unordered_map<std::string,PyObject * > & kwargs)239 std::string _formattedArgDesc(
240     const Option& option,
241     const std::vector<PyObject*>& arguments,
242     const std::unordered_map<std::string, PyObject*>& kwargs) {
243   std::string red;
244   std::string reset_red;
245   std::string green;
246   std::string reset_green;
247   if (isatty(1) && isatty(2)) {
248     red = "\33[31;1m";
249     reset_red = "\33[0m";
250     green = "\33[32;1m";
251     reset_green = "\33[0m";
252   } else {
253     red = "!";
254     reset_red = "!";
255     green = "";
256     reset_green = "";
257   }
258 
259   auto num_args = arguments.size() + kwargs.size();
260   std::string result = "(";
261   for (const auto i : c10::irange(num_args)) {
262     bool is_kwarg = i >= arguments.size();
263     PyObject* arg =
264         is_kwarg ? kwargs.at(option.arguments[i].name) : arguments[i];
265 
266     bool is_matching = false;
267     if (i < option.arguments.size()) {
268       is_matching = option.arguments[i].type->is_matching(arg);
269     } else if (option.is_variadic) {
270       is_matching = option.arguments.back().type->is_matching(arg);
271     }
272 
273     if (is_matching)
274       result += green;
275     else
276       result += red;
277     if (is_kwarg)
278       result += option.arguments[i].name + "=";
279     bool is_tuple = PyTuple_Check(arg);
280     if (is_tuple || PyList_Check(arg)) {
281       result += py_typename(arg) + " of ";
282       auto num_elements = PySequence_Length(arg);
283       if (is_tuple) {
284         result += "(";
285       } else {
286         result += "[";
287       }
288       for (const auto i : c10::irange(num_elements)) {
289         if (i != 0) {
290           result += ", ";
291         }
292         result += py_typename(
293             py::reinterpret_steal<py::object>(PySequence_GetItem(arg, i))
294                 .ptr());
295       }
296       if (is_tuple) {
297         if (num_elements == 1) {
298           result += ",";
299         }
300         result += ")";
301       } else {
302         result += "]";
303       }
304     } else {
305       result += py_typename(arg);
306     }
307     if (is_matching)
308       result += reset_green;
309     else
310       result += reset_red;
311     result += ", ";
312   }
313   if (!arguments.empty())
314     result.erase(result.length() - 2);
315   result += ")";
316   return result;
317 }
318 
_argDesc(const std::vector<PyObject * > & arguments,const std::unordered_map<std::string,PyObject * > & kwargs)319 std::string _argDesc(
320     const std::vector<PyObject*>& arguments,
321     const std::unordered_map<std::string, PyObject*>& kwargs) {
322   std::string result = "(";
323   for (auto& arg : arguments)
324     result += std::string(py_typename(arg)) + ", ";
325   for (auto& kwarg : kwargs)
326     result += kwarg.first + "=" + py_typename(kwarg.second) + ", ";
327   if (!arguments.empty())
328     result.erase(result.length() - 2);
329   result += ")";
330   return result;
331 }
332 
_tryMatchKwargs(const Option & option,const std::unordered_map<std::string,PyObject * > & kwargs)333 std::vector<std::string> _tryMatchKwargs(
334     const Option& option,
335     const std::unordered_map<std::string, PyObject*>& kwargs) {
336   std::vector<std::string> unmatched;
337   // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
338   int64_t start_idx = option.arguments.size() - kwargs.size();
339   if (option.has_out && kwargs.count("out") == 0)
340     start_idx--;
341   if (start_idx < 0)
342     start_idx = 0;
343   for (auto& entry : kwargs) {
344     bool found = false;
345     for (unsigned int i = start_idx; i < option.arguments.size(); i++) {
346       if (option.arguments[i].name == entry.first) {
347         found = true;
348         break;
349       }
350     }
351     if (!found)
352       unmatched.push_back(entry.first);
353   }
354   return unmatched;
355 }
356 
357 } // anonymous namespace
358 
format_invalid_args(PyObject * given_args,PyObject * given_kwargs,const std::string & function_name,const std::vector<std::string> & options)359 std::string format_invalid_args(
360     PyObject* given_args,
361     PyObject* given_kwargs,
362     const std::string& function_name,
363     const std::vector<std::string>& options) {
364   std::vector<PyObject*> args;
365   std::unordered_map<std::string, PyObject*> kwargs;
366   std::string error_msg;
367   error_msg.reserve(2000);
368   error_msg += function_name;
369   error_msg += " received an invalid combination of arguments - ";
370 
371   Py_ssize_t num_args = PyTuple_Size(given_args);
372   for (const auto i : c10::irange(num_args)) {
373     PyObject* arg = PyTuple_GET_ITEM(given_args, i);
374     args.push_back(arg);
375   }
376 
377   bool has_kwargs = given_kwargs && PyDict_Size(given_kwargs) > 0;
378   if (has_kwargs) {
379     PyObject *key = nullptr, *value = nullptr;
380     Py_ssize_t pos = 0;
381 
382     while (PyDict_Next(given_kwargs, &pos, &key, &value)) {
383       kwargs.emplace(THPUtils_unpackString(key), value);
384     }
385   }
386 
387   if (options.size() == 1) {
388     auto pair = _parseOption(options[0], kwargs);
389     auto& option = pair.first;
390     auto& option_str = pair.second;
391     std::vector<std::string> unmatched_kwargs;
392     if (has_kwargs)
393       unmatched_kwargs = _tryMatchKwargs(option, kwargs);
394     if (!unmatched_kwargs.empty()) {
395       error_msg += "got unrecognized keyword arguments: ";
396       for (auto& kwarg : unmatched_kwargs)
397         error_msg += kwarg + ", ";
398       error_msg.erase(error_msg.length() - 2);
399     } else {
400       error_msg += "got ";
401       if (_argcountMatch(option, args, kwargs)) {
402         error_msg += _formattedArgDesc(option, args, kwargs);
403       } else {
404         error_msg += _argDesc(args, kwargs);
405       }
406       error_msg += ", but expected ";
407       error_msg += option_str;
408     }
409   } else {
410     error_msg += "got ";
411     error_msg += _argDesc(args, kwargs);
412     error_msg += ", but expected one of:\n";
413     for (auto& option_str : options) {
414       auto pair = _parseOption(option_str, kwargs);
415       auto& option = pair.first;
416       auto& printable_option_str = pair.second;
417       error_msg += " * ";
418       error_msg += printable_option_str;
419       error_msg += "\n";
420       if (_argcountMatch(option, args, kwargs)) {
421         std::vector<std::string> unmatched_kwargs;
422         if (has_kwargs)
423           unmatched_kwargs = _tryMatchKwargs(option, kwargs);
424         if (!unmatched_kwargs.empty()) {
425           error_msg +=
426               "      didn't match because some of the keywords were incorrect: ";
427           for (auto& kwarg : unmatched_kwargs)
428             error_msg += kwarg + ", ";
429           error_msg.erase(error_msg.length() - 2);
430           error_msg += "\n";
431         } else {
432           error_msg +=
433               "      didn't match because some of the arguments have invalid types: ";
434           error_msg += _formattedArgDesc(option, args, kwargs);
435           error_msg += "\n";
436         }
437       }
438     }
439   }
440   return error_msg;
441 }
442 
443 } // namespace torch
444