1 #include <torch/csrc/utils/invalid_arguments.h>
2
3 #include <torch/csrc/utils/python_strings.h>
4
5 #include <c10/util/irange.h>
6
7 #include <algorithm>
8 #include <memory>
9 #include <unordered_map>
10
11 namespace torch {
12
13 namespace {
14
py_typename(PyObject * object)15 std::string py_typename(PyObject* object) {
16 return Py_TYPE(object)->tp_name;
17 }
18
19 struct Type {
20 Type() = default;
21 Type(const Type&) = default;
22 Type& operator=(const Type&) = default;
23 Type(Type&&) noexcept = default;
24 Type& operator=(Type&&) noexcept = default;
25 virtual bool is_matching(PyObject* object) = 0;
26 virtual ~Type() = default;
27 };
28
29 struct SimpleType : public Type {
SimpleTypetorch::__anon9622be550111::SimpleType30 SimpleType(std::string& name) : name(name){};
31
is_matchingtorch::__anon9622be550111::SimpleType32 bool is_matching(PyObject* object) override {
33 return py_typename(object) == name;
34 }
35
36 std::string name;
37 };
38
39 struct MultiType : public Type {
MultiTypetorch::__anon9622be550111::MultiType40 MultiType(std::initializer_list<std::string> accepted_types)
41 : types(accepted_types){};
42
is_matchingtorch::__anon9622be550111::MultiType43 bool is_matching(PyObject* object) override {
44 auto it = std::find(types.begin(), types.end(), py_typename(object));
45 return it != types.end();
46 }
47
48 std::vector<std::string> types;
49 };
50
51 struct NullableType : public Type {
NullableTypetorch::__anon9622be550111::NullableType52 NullableType(std::unique_ptr<Type> type) : type(std::move(type)){};
53
is_matchingtorch::__anon9622be550111::NullableType54 bool is_matching(PyObject* object) override {
55 return object == Py_None || type->is_matching(object);
56 }
57
58 std::unique_ptr<Type> type;
59 };
60
61 struct TupleType : public Type {
TupleTypetorch::__anon9622be550111::TupleType62 TupleType(std::vector<std::unique_ptr<Type>> types)
63 : types(std::move(types)){};
64
is_matchingtorch::__anon9622be550111::TupleType65 bool is_matching(PyObject* object) override {
66 if (!PyTuple_Check(object))
67 return false;
68 auto num_elements = PyTuple_GET_SIZE(object);
69 if (num_elements != (long)types.size())
70 return false;
71 for (const auto i : c10::irange(num_elements)) {
72 if (!types[i]->is_matching(PyTuple_GET_ITEM(object, i)))
73 return false;
74 }
75 return true;
76 }
77
78 std::vector<std::unique_ptr<Type>> types;
79 };
80
81 struct SequenceType : public Type {
SequenceTypetorch::__anon9622be550111::SequenceType82 SequenceType(std::unique_ptr<Type> type) : type(std::move(type)){};
83
is_matchingtorch::__anon9622be550111::SequenceType84 bool is_matching(PyObject* object) override {
85 if (!PySequence_Check(object))
86 return false;
87 auto num_elements = PySequence_Length(object);
88 for (const auto i : c10::irange(num_elements)) {
89 if (!type->is_matching(
90 py::reinterpret_steal<py::object>(PySequence_GetItem(object, i))
91 .ptr()))
92 return false;
93 }
94 return true;
95 }
96
97 std::unique_ptr<Type> type;
98 };
99
100 struct Argument {
Argumenttorch::__anon9622be550111::Argument101 Argument(std::string name, std::unique_ptr<Type> type)
102 : name(std::move(name)), type(std::move(type)){};
103
104 std::string name;
105 std::unique_ptr<Type> type;
106 };
107
108 struct Option {
Optiontorch::__anon9622be550111::Option109 Option(std::vector<Argument> arguments, bool is_variadic, bool has_out)
110 : arguments(std::move(arguments)),
111 is_variadic(is_variadic),
112 has_out(has_out){};
Optiontorch::__anon9622be550111::Option113 Option(bool is_variadic, bool has_out)
114 : arguments(), is_variadic(is_variadic), has_out(has_out){};
115 Option(const Option&) = delete;
116 Option(Option&& other) noexcept = default;
117 Option& operator=(const Option&) = delete;
118 Option& operator=(Option&&) = delete;
119
120 std::vector<Argument> arguments;
121 bool is_variadic;
122 bool has_out;
123 };
124
// Splits `s` on every occurrence of `delim`, keeping empty tokens, so N
// delimiters always produce N + 1 tokens. An empty `delim` yields {s}.
std::vector<std::string> _splitString(
    const std::string& s,
    const std::string& delim) {
  std::vector<std::string> tokens;
  if (delim.empty()) {
    // find("") matches at `start` without advancing, so the loop below
    // would never terminate.
    tokens.push_back(s);
    return tokens;
  }
  size_t start = 0;
  size_t end = 0;
  while ((end = s.find(delim, start)) != std::string::npos) {
    tokens.emplace_back(s, start, end - start);
    start = end + delim.length();
  }
  tokens.emplace_back(s, start);
  return tokens;
}
138
_buildType(std::string type_name,bool is_nullable)139 std::unique_ptr<Type> _buildType(std::string type_name, bool is_nullable) {
140 std::unique_ptr<Type> result;
141 if (type_name == "float") {
142 result = std::make_unique<MultiType>(MultiType{"float", "int", "long"});
143 } else if (type_name == "int") {
144 result = std::make_unique<MultiType>(MultiType{"int", "long"});
145 } else if (type_name.find("tuple[") == 0) {
146 auto type_list = type_name.substr(6);
147 type_list.pop_back();
148 std::vector<std::unique_ptr<Type>> types;
149 for (auto& type : _splitString(type_list, ","))
150 types.emplace_back(_buildType(type, false));
151 result = std::make_unique<TupleType>(std::move(types));
152 } else if (type_name.find("sequence[") == 0) {
153 auto subtype = type_name.substr(9);
154 subtype.pop_back();
155 result = std::make_unique<SequenceType>(_buildType(subtype, false));
156 } else {
157 result = std::make_unique<SimpleType>(type_name);
158 }
159 if (is_nullable)
160 result = std::make_unique<NullableType>(std::move(result));
161 return result;
162 }
163
_parseOption(const std::string & _option_str,const std::unordered_map<std::string,PyObject * > & kwargs)164 std::pair<Option, std::string> _parseOption(
165 const std::string& _option_str,
166 const std::unordered_map<std::string, PyObject*>& kwargs) {
167 if (_option_str == "no arguments")
168 return std::pair<Option, std::string>(Option(false, false), _option_str);
169 bool has_out = false;
170 std::vector<Argument> arguments;
171 std::string printable_option = _option_str;
172 std::string option_str = _option_str.substr(1, _option_str.length() - 2);
173
174 /// XXX: this is a hack only for the out arg in TensorMethods
175 auto out_pos = printable_option.find('#');
176 if (out_pos != std::string::npos) {
177 if (kwargs.count("out") > 0) {
178 std::string kwonly_part = printable_option.substr(out_pos + 1);
179 printable_option.erase(out_pos);
180 printable_option += "*, ";
181 printable_option += kwonly_part;
182 } else if (out_pos >= 2) {
183 printable_option.erase(out_pos - 2);
184 printable_option += ")";
185 } else {
186 printable_option.erase(out_pos);
187 printable_option += ")";
188 }
189 has_out = true;
190 }
191
192 for (auto& arg : _splitString(option_str, ", ")) {
193 bool is_nullable = false;
194 auto type_start_idx = 0;
195 if (arg[type_start_idx] == '#') {
196 type_start_idx++;
197 }
198 if (arg[type_start_idx] == '[') {
199 is_nullable = true;
200 type_start_idx++;
201 arg.erase(arg.length() - std::string(" or None]").length());
202 }
203
204 auto type_end_idx = arg.find_last_of(' ');
205 auto name_start_idx = type_end_idx + 1;
206
207 // "type ... name" => "type ... name"
208 // ^ ^
209 auto dots_idx = arg.find("...");
210 if (dots_idx != std::string::npos)
211 type_end_idx -= 4;
212
213 std::string type_name =
214 arg.substr(type_start_idx, type_end_idx - type_start_idx);
215 std::string name = arg.substr(name_start_idx);
216
217 arguments.emplace_back(name, _buildType(type_name, is_nullable));
218 }
219
220 bool is_variadic = option_str.find("...") != std::string::npos;
221 return std::pair<Option, std::string>(
222 Option(std::move(arguments), is_variadic, has_out),
223 std::move(printable_option));
224 }
225
// Returns true when the number of supplied arguments (positional plus
// keyword) fits `option`. The out parameter only counts when it was passed
// as `out=`; a variadic option additionally accepts any surplus.
bool _argcountMatch(
    const Option& option,
    const std::vector<PyObject*>& arguments,
    const std::unordered_map<std::string, PyObject*>& kwargs) {
  const auto num_got = arguments.size() + kwargs.size();
  // Note: variadic functions don't accept kwargs, so it's ok
  const bool out_omitted = option.has_out && kwargs.count("out") == 0;
  const auto num_expected =
      option.arguments.size() - (out_omitted ? size_t{1} : size_t{0});
  if (num_got == num_expected)
    return true;
  return option.is_variadic && num_got > num_expected;
}
238
// Renders the supplied arguments as "(T1, T2, name=T3)", coloring each entry
// by whether it matches the corresponding parameter of `option`: ANSI
// green/red when both fd 1 and fd 2 are terminals, otherwise mismatches are
// wrapped in '!' markers. Tuples and lists also show their element types,
// e.g. "tuple of (int, float)".
//
// Positional argument i is checked against option.arguments[i] (or, past the
// end of a variadic option, against its last parameter). Keyword arguments
// are checked against the parameter with the same name and listed in
// declaration order; keywords the option does not declare are shown as
// mismatches.
std::string _formattedArgDesc(
    const Option& option,
    const std::vector<PyObject*>& arguments,
    const std::unordered_map<std::string, PyObject*>& kwargs) {
  std::string red;
  std::string reset_red;
  std::string green;
  std::string reset_green;
  if (isatty(1) && isatty(2)) {
    red = "\33[31;1m";
    reset_red = "\33[0m";
    green = "\33[32;1m";
    reset_green = "\33[0m";
  } else {
    red = "!";
    reset_red = "!";
    green = "";
    reset_green = "";
  }

  std::string result = "(";

  // Appends one ", "-terminated entry for `arg`. `spec` is the parameter the
  // value is checked against (null when there is none, which renders as a
  // mismatch); `kwarg_name` is non-null for keyword arguments.
  auto append_arg = [&](PyObject* arg,
                        const Argument* spec,
                        const std::string* kwarg_name) {
    const bool is_matching = spec != nullptr && spec->type->is_matching(arg);
    result += is_matching ? green : red;
    if (kwarg_name != nullptr)
      result += *kwarg_name + "=";
    const bool is_tuple = PyTuple_Check(arg);
    if (is_tuple || PyList_Check(arg)) {
      result += py_typename(arg) + " of ";
      result += is_tuple ? "(" : "[";
      const auto num_elements = PySequence_Length(arg);
      for (const auto j : c10::irange(num_elements)) {
        if (j != 0)
          result += ", ";
        result += py_typename(
            py::reinterpret_steal<py::object>(PySequence_GetItem(arg, j))
                .ptr());
      }
      if (is_tuple) {
        // A one-element tuple prints as "(int,)", like Python's repr.
        if (num_elements == 1)
          result += ",";
        result += ")";
      } else {
        result += "]";
      }
    } else {
      result += py_typename(arg);
    }
    result += is_matching ? reset_green : reset_red;
    result += ", ";
  };

  for (const auto i : c10::irange(arguments.size())) {
    const Argument* spec = nullptr;
    if (i < option.arguments.size()) {
      spec = &option.arguments[i];
    } else if (option.is_variadic && !option.arguments.empty()) {
      spec = &option.arguments.back();
    }
    append_arg(arguments[i], spec, nullptr);
  }

  // Keywords bind by name rather than by slot position, which keeps every
  // lookup in bounds even when a variadic option receives keywords.
  for (const auto& spec : option.arguments) {
    const auto it = kwargs.find(spec.name);
    if (it != kwargs.end())
      append_arg(it->second, &spec, &it->first);
  }
  for (const auto& entry : kwargs) {
    const bool declared = std::any_of(
        option.arguments.begin(),
        option.arguments.end(),
        [&entry](const Argument& a) { return a.name == entry.first; });
    if (!declared)
      append_arg(entry.second, nullptr, &entry.first);
  }

  // Drop the trailing ", " whenever anything was appended, including the
  // keyword-only case.
  if (result.length() > 1)
    result.erase(result.length() - 2);
  result += ")";
  return result;
}
318
_argDesc(const std::vector<PyObject * > & arguments,const std::unordered_map<std::string,PyObject * > & kwargs)319 std::string _argDesc(
320 const std::vector<PyObject*>& arguments,
321 const std::unordered_map<std::string, PyObject*>& kwargs) {
322 std::string result = "(";
323 for (auto& arg : arguments)
324 result += std::string(py_typename(arg)) + ", ";
325 for (auto& kwarg : kwargs)
326 result += kwarg.first + "=" + py_typename(kwarg.second) + ", ";
327 if (!arguments.empty())
328 result.erase(result.length() - 2);
329 result += ")";
330 return result;
331 }
332
_tryMatchKwargs(const Option & option,const std::unordered_map<std::string,PyObject * > & kwargs)333 std::vector<std::string> _tryMatchKwargs(
334 const Option& option,
335 const std::unordered_map<std::string, PyObject*>& kwargs) {
336 std::vector<std::string> unmatched;
337 // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
338 int64_t start_idx = option.arguments.size() - kwargs.size();
339 if (option.has_out && kwargs.count("out") == 0)
340 start_idx--;
341 if (start_idx < 0)
342 start_idx = 0;
343 for (auto& entry : kwargs) {
344 bool found = false;
345 for (unsigned int i = start_idx; i < option.arguments.size(); i++) {
346 if (option.arguments[i].name == entry.first) {
347 found = true;
348 break;
349 }
350 }
351 if (!found)
352 unmatched.push_back(entry.first);
353 }
354 return unmatched;
355 }
356
357 } // anonymous namespace
358
format_invalid_args(PyObject * given_args,PyObject * given_kwargs,const std::string & function_name,const std::vector<std::string> & options)359 std::string format_invalid_args(
360 PyObject* given_args,
361 PyObject* given_kwargs,
362 const std::string& function_name,
363 const std::vector<std::string>& options) {
364 std::vector<PyObject*> args;
365 std::unordered_map<std::string, PyObject*> kwargs;
366 std::string error_msg;
367 error_msg.reserve(2000);
368 error_msg += function_name;
369 error_msg += " received an invalid combination of arguments - ";
370
371 Py_ssize_t num_args = PyTuple_Size(given_args);
372 for (const auto i : c10::irange(num_args)) {
373 PyObject* arg = PyTuple_GET_ITEM(given_args, i);
374 args.push_back(arg);
375 }
376
377 bool has_kwargs = given_kwargs && PyDict_Size(given_kwargs) > 0;
378 if (has_kwargs) {
379 PyObject *key = nullptr, *value = nullptr;
380 Py_ssize_t pos = 0;
381
382 while (PyDict_Next(given_kwargs, &pos, &key, &value)) {
383 kwargs.emplace(THPUtils_unpackString(key), value);
384 }
385 }
386
387 if (options.size() == 1) {
388 auto pair = _parseOption(options[0], kwargs);
389 auto& option = pair.first;
390 auto& option_str = pair.second;
391 std::vector<std::string> unmatched_kwargs;
392 if (has_kwargs)
393 unmatched_kwargs = _tryMatchKwargs(option, kwargs);
394 if (!unmatched_kwargs.empty()) {
395 error_msg += "got unrecognized keyword arguments: ";
396 for (auto& kwarg : unmatched_kwargs)
397 error_msg += kwarg + ", ";
398 error_msg.erase(error_msg.length() - 2);
399 } else {
400 error_msg += "got ";
401 if (_argcountMatch(option, args, kwargs)) {
402 error_msg += _formattedArgDesc(option, args, kwargs);
403 } else {
404 error_msg += _argDesc(args, kwargs);
405 }
406 error_msg += ", but expected ";
407 error_msg += option_str;
408 }
409 } else {
410 error_msg += "got ";
411 error_msg += _argDesc(args, kwargs);
412 error_msg += ", but expected one of:\n";
413 for (auto& option_str : options) {
414 auto pair = _parseOption(option_str, kwargs);
415 auto& option = pair.first;
416 auto& printable_option_str = pair.second;
417 error_msg += " * ";
418 error_msg += printable_option_str;
419 error_msg += "\n";
420 if (_argcountMatch(option, args, kwargs)) {
421 std::vector<std::string> unmatched_kwargs;
422 if (has_kwargs)
423 unmatched_kwargs = _tryMatchKwargs(option, kwargs);
424 if (!unmatched_kwargs.empty()) {
425 error_msg +=
426 " didn't match because some of the keywords were incorrect: ";
427 for (auto& kwarg : unmatched_kwargs)
428 error_msg += kwarg + ", ";
429 error_msg.erase(error_msg.length() - 2);
430 error_msg += "\n";
431 } else {
432 error_msg +=
433 " didn't match because some of the arguments have invalid types: ";
434 error_msg += _formattedArgDesc(option, args, kwargs);
435 error_msg += "\n";
436 }
437 }
438 }
439 }
440 return error_msg;
441 }
442
443 } // namespace torch
444