#include <torch/csrc/utils/python_arg_parser.h>

#include <torch/csrc/Exceptions.h>
#include <torch/csrc/Layout.h>
#include <torch/csrc/MemoryFormat.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/utils/invalid_arguments.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/python_torch_function_mode.h>
#include <torch/csrc/utils/torch_dispatch_mode.h>

#include <ATen/ATen.h>
#include <ATen/PythonTorchFunctionTLS.h>
#include <ATen/TracerMode.h>
#include <c10/util/irange.h>

#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

namespace torch {

static std::unordered_map<std::string, ParameterType> type_map = {
    {"Tensor", ParameterType::TENSOR},
    {"Scalar", ParameterType::SCALAR},
    {"int64_t", ParameterType::INT64},
    {"SymInt", ParameterType::SYM_INT},
    {"double", ParameterType::DOUBLE},
    {"complex", ParameterType::COMPLEX},
    {"TensorList", ParameterType::TENSOR_LIST},
    {"c10::List<::std::optional<Tensor>>", ParameterType::TENSOR_LIST},
    {"IntArrayRef", ParameterType::INT_LIST},
    {"SymIntArrayRef", ParameterType::SYM_INT_LIST},
    {"ArrayRef<double>", ParameterType::FLOAT_LIST},
    {"Generator", ParameterType::GENERATOR},
    {"bool", ParameterType::BOOL},
    {"Storage", ParameterType::STORAGE},
    {"PyObject*", ParameterType::PYOBJECT},
    {"ScalarType", ParameterType::SCALARTYPE},
    {"Layout", ParameterType::LAYOUT},
    {"MemoryFormat", ParameterType::MEMORY_FORMAT},
    {"QScheme", ParameterType::QSCHEME},
    {"Device", ParameterType::DEVICE},
    {"DeviceIndex", ParameterType::INT64},
    {"Stream", ParameterType::STREAM},
    {"std::string", ParameterType::STRING},
    {"c10::string_view", ParameterType::STRING},
    {"Dimname", ParameterType::DIMNAME},
    {"DimnameList", ParameterType::DIMNAME_LIST},
    {"ScalarList", ParameterType::SCALAR_LIST},
    {"DispatchKeySet", ParameterType::DISPATCH_KEY_SET},
};

// Default arg name translations for compatibility with NumPy.
//
// Example:
// ```python
// t = torch.randn(10,10)
// torch.sum(a=t, axis=0, keepdim=True)
// ```
//
// A vector is necessary, because we might need to try multiple values.
// In particular, NumPy sometimes uses "x" and sometimes "a" for the main input
// tensor. Rather than annotate each function separately with whether it should
// take "x" or "a", just try both.
//
// TODO: Allow individual functions to specify non-default translations:
// For example, `torch.pow` should translate "exponent" to "x2".
static const std::unordered_map<std::string, std::vector<std::string>>
    numpy_compatibility_arg_names = {
        {"dim", {"axis"}},
        {"keepdim", {"keepdims"}},
        {"input", {"x", "a", "x1"}},
        {"other", {"x2"}},
};

// TODO: remove this. This is a temporary list of functions that allow Python
// numbers to bind to Tensors. Some binary ops have separate Tensor and Scalar
// overloads and binding to the Tensor overload with a number of a different
// type will trigger a type error.
//
// If you modify this, you will need to adjust the blocklist in
// tools/pyi/gen_pyi.py (and add hardcoded signatures for these
// functions.)
bool should_allow_numbers_as_tensors(const std::string& name) {
  static std::unordered_set<std::string> allowed = {
      "add", "add_", "add_out",
      "div", "div_", "div_out",
      "divide", "divide_", "divide_out", // alias of div
      "mul", "mul_", "mul_out",
      "multiply", "multiply_", "multiply_out", // alias of mul
      "sub", "sub_", "sub_out",
      "subtract", "subtract_", "subtract_out", // alias of sub
      "true_divide", "true_divide_", "true_divide_out",
      "to", "_to_copy", "copy_",
      "floor_divide", "floor_divide_", "floor_divide_out",
      "_conj"}; // _conj needed because mul.Tensor backward calls it
  return allowed.find(name) != allowed.end();
}
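
// For illustration (this example is not part of the parser itself): with
// "add" on the allowlist above, a plain Python number may bind to the
// Tensor-typed parameter of the Tensor overload:
// ```python
// t = torch.randn(3)
// torch.add(t, 1.5)  # 1.5 is accepted where a Tensor is expected
// ```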

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
    : optional(false),
      allow_none(false),
      keyword_only(keyword_only),
      size(0),
      default_scalar(0) {
  auto space = fmt.find(' ');
  if (space == std::string::npos) {
    throw std::runtime_error("FunctionParameter(): missing type: " + fmt);
  }

  auto type_str = fmt.substr(0, space);

  auto question = type_str.find('?');
  if (question != std::string::npos) {
    allow_none = true;
    type_str = type_str.substr(0, question);
  }

  // Parse and remove brackets from type_str
  auto bracket = type_str.find('[');
  if (bracket != std::string::npos) {
    auto size_str =
        type_str.substr(bracket + 1, type_str.length() - bracket - 2);
    size = atoi(size_str.c_str());
    type_str = type_str.substr(0, bracket);
  }

  auto name_str = fmt.substr(space + 1);
  auto it = type_map.find(type_str);
  if (it == type_map.end()) {
    throw std::runtime_error(
        "FunctionParameter(): invalid type string: " + type_str);
  }
  type_ = it->second;

  auto eq = name_str.find('=');
  if (eq != std::string::npos) {
    name = name_str.substr(0, eq);
    optional = true;
    set_default_str(name_str.substr(eq + 1));
  } else {
    name = name_str;
  }
  python_name = THPUtils_internString(name);
  auto np_compat_it = numpy_compatibility_arg_names.find(name);
  if (np_compat_it != numpy_compatibility_arg_names.end()) {
    for (const auto& str : np_compat_it->second) {
      numpy_python_names.push_back(THPUtils_internString(str));
    }
  }
}
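
// Illustrative format strings accepted by the constructor above (the type,
// an optional '?' for allow_none, an optional '[size]', then the name and an
// optional '=default'):
//   "Tensor input"            -> required Tensor named "input"
//   "Scalar alpha=1"          -> optional Scalar defaulting to 1
//   "IntArrayRef[2] stride=1" -> size-2 int list; the default 1 expands to {1,1}
//   "Tensor? out=None"        -> Tensor argument that may also be None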

auto handle_torch_function_getter(
    THPVariable* self,
    const std::string& property_name) -> PyObject* {
  py::object torch_api = PyObject_FastGetAttrString(
      THPVariableClass, (char*)property_name.c_str());
  std::string module_name = "torch.Tensor." + property_name;
  return handle_torch_function(
      (PyObject*)self,
      "__get__",
      nullptr,
      nullptr,
      torch_api.ptr(),
      module_name);
}

auto handle_torch_function_setter(
    THPVariable* self,
    const std::string& property_name,
    PyObject* value) -> int {
  py::object torch_api = PyObject_FastGetAttrString(
      THPVariableClass, (char*)property_name.c_str());
  std::string module_name = "torch.Tensor." + property_name;
  if (value != nullptr) {
    py::tuple args_ = py::make_tuple(py::handle(value));
    handle_torch_function(
        (PyObject*)self,
        "__set__",
        args_.ptr(),
        nullptr,
        torch_api.ptr(),
        module_name);
  } else {
    handle_torch_function(
        (PyObject*)self,
        "__delete__",
        nullptr,
        nullptr,
        torch_api.ptr(),
        module_name);
  }
  return 0;
}

// Combines self and args into one tuple.
static auto combine_self_args(PyObject* self, PyObject* args) -> py::tuple {
  if (args == nullptr) {
    return py::make_tuple(py::handle(self));
  } else if (self == nullptr) {
    return py::reinterpret_borrow<py::tuple>(args);
  }

  auto py_args = py::reinterpret_borrow<py::tuple>(args);
  size_t n = py_args.size();
  auto args_ = py::tuple(n + 1);
  args_[0] = py::handle(self);
  for (const auto i : c10::irange(n)) {
    args_[i + 1] = py_args[i];
  }
  return args_;
}
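
// e.g. combine_self_args(self, (a, b)) -> (self, a, b);
//      combine_self_args(self, nullptr) -> (self,);
//      combine_self_args(nullptr, args) borrows args unchanged.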

// TODO: I'm not sure if I should call this __torch_function__ or
// torch_function. The former makes it easier to take an existing
// Tensor-like __torch_function__ object and turn it into a mode;
// but in general modes don't have to be Tensor-like (and we will
// improperly accept mode objects as arguments when they shouldn't
// be passed around in this way).
const char* torch_function_mode_name = "__torch_function__";

auto handle_torch_function(
    PyObject* self,
    const std::string& func_name,
    PyObject* args,
    PyObject* kwargs,
    PyObject* torch_api,
    const std::string& module_name) -> PyObject* {
  py::object torch_api_function =
      PyObject_FastGetAttrString(torch_api, (char*)func_name.c_str());
  TORCH_INTERNAL_ASSERT(
      torch_api_function.ptr() != nullptr, "torch API function must exist");
  py::tuple args_ = combine_self_args(self, args);
  return handle_torch_function_no_python_arg_parser(
      {self},
      args_.ptr(),
      kwargs,
      func_name.c_str(),
      torch_api_function.ptr(),
      module_name.c_str(),
      TorchFunctionName::TorchFunction);
}

// Note: [Overloaded args]
// An overloaded arg may be one of the following:
// - an instance of an object that has a __torch_function__ method
// - an instance of an object that has a __torch_dispatch__ classmethod
// - a class type that has a __torch_dispatch__ classmethod
//
// This function returns the type of the arg (if the arg is an instance),
// otherwise, it returns the arg.
static PyObject* get_type_of_overloaded_arg(PyObject* obj_or_type) {
  if (PyType_Check(obj_or_type)) {
    return obj_or_type;
  }
  return (PyObject*)Py_TYPE(obj_or_type);
}
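
// A minimal sketch of an instance-style overloaded arg (for illustration;
// MyTensor is a hypothetical class). For an instance of it, the function
// above returns the class MyTensor rather than the instance:
// ```python
// class MyTensor(torch.Tensor):
//     @classmethod
//     def __torch_function__(cls, func, types, args=(), kwargs=None):
//         if kwargs is None:
//             kwargs = {}
//         return super().__torch_function__(func, types, args, kwargs)
// ```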

static py::object maybe_get_registered_torch_dispatch_rule(
    PyObject* torch_api_function,
    const py::object& torch_dispatch_object) {
  // This is a static object, so we must leak the Python object.
  // "release()" is used here to preserve 1 refcount on the
  // object, preventing it from ever being de-allocated by CPython.
  static const py::handle find_torch_dispatch_rule =
      py::object(py::module_::import("torch._library.simple_registry")
                     .attr("find_torch_dispatch_rule"))
          .release();
  auto result = find_torch_dispatch_rule(
      py::reinterpret_borrow<py::object>(torch_api_function),
      torch_dispatch_object.get_type());
  return result;
}

static py::object dispatch_on_subclass(
    PyObject* args,
    PyObject* kwargs,
    at::ArrayRef<PyObject*> overloaded_args,
    py::tuple py_types,
    PyObject* torch_api_function,
    bool is_torch_function,
    const char* torch_function_name_str,
    std::optional<c10::impl::TorchDispatchModeKey> maybe_mode_key =
        std::nullopt) {
  py::object ret;
  for (auto& arg : overloaded_args) {
    py::object torch_function =
        PyObject_FastGetAttrString(arg, torch_function_name_str);
    if (!torch_function) {
      TORCH_INTERNAL_ASSERT(0);
    }
    if (torch_function.ptr() == torch::disabled_torch_dispatch_impl()) {
      // During __torch_dispatch__, don't dispatch on args with a disabled
      // torch_dispatch. This code runs before infra modes, so we need to make
      // sure that infra modes can run first. (In theory, maybe we can rearrange
      // things so that infra modes are *always* attempted first, and just
      // return NotImplemented when there are any user subclasses. Maybe that
      // would fix this problem?)
      continue;
    }

    // See https://github.com/pytorch/pytorch/issues/63767
    if (is_torch_function &&
        PyObject_FastGetAttrString(torch_function.ptr(), "__self__")
            .is(py::handle(arg)) &&
        torch_function.ptr() != torch::disabled_torch_function_impl()) {
      TORCH_WARN_ONCE(
          "Defining your `",
          torch_function_name_str,
          "` as a plain method is deprecated ",
          "and will be an error in future, please define it as a classmethod.");
    }

    if (!is_torch_function) {
      auto maybe_torch_dispatch_rule = maybe_get_registered_torch_dispatch_rule(
          torch_api_function, py::reinterpret_borrow<py::object>(arg));
      if (!maybe_torch_dispatch_rule.is_none()) {
        torch_function = maybe_torch_dispatch_rule;
        auto py_arg = py::reinterpret_borrow<py::object>(arg);
        ret = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(
            torch_function.ptr(),
            py_arg.get_type().ptr(),
            torch_api_function,
            py_types.ptr(),
            args,
            kwargs,
            NULL));
        if (ret.ptr() == nullptr) {
          throw python_error();
        }
        if (ret.ptr() != Py_NotImplemented) {
          break;
        }
      }
    }

    ret = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(
        torch_function.ptr(),
        torch_api_function,
        py_types.ptr(),
        args,
        kwargs,
        NULL));
    if (ret.ptr() == nullptr) {
      throw python_error();
    }
    if (ret.ptr() != Py_NotImplemented) {
      // Return the reference to the result. This also covers the case where
      // ret is NULL and __torch_function__/__torch_dispatch__ raised an
      // exception, which we throw below.
      break;
    }
  }
  return ret;
}

static std::tuple<py::object, py::object> dispatch_on_mode(
    PyObject* args,
    PyObject* kwargs,
    py::tuple py_types,
    PyObject* torch_api_function,
    bool is_torch_function,
    const char* torch_function_name_str) {
  // Disable mode on the inside; this makes for a more user-friendly
  // experience if you try to, e.g., print your tensors.
  std::optional<torch::overrides::StashTorchFunctionModeGuard> tf_g;
  std::optional<torch_dispatch_mode::StashTorchDispatchModeGuard> td_g;
  py::object mode_obj;
  // NB: We only really need to keep the mode_obj live if the function call
  // fails for error reporting, but whatever, Python refcounts are cheap
  if (is_torch_function) {
    tf_g.emplace();
    mode_obj = py::reinterpret_borrow<py::object>(
        tf_g->get_cur_mode()->ptr(getPyInterpreter()));
  } else {
    td_g.emplace();
    mode_obj = py::reinterpret_borrow<py::object>(
        td_g->get_cur_mode()->ptr(getPyInterpreter()));
  }
  py::object torch_function =
      PyObject_FastGetAttrString(mode_obj.ptr(), torch_function_name_str);
  if (!torch_function) {
    TORCH_INTERNAL_ASSERT(0);
  }
  TORCH_INTERNAL_ASSERT(py_types.ptr() != nullptr);
  TORCH_INTERNAL_ASSERT(args != nullptr);

  TORCH_CHECK(
      PyObject_FastGetAttrString(torch_function.ptr(), "__self__").is(mode_obj),
      "Defining your mode's `",
      torch_function_name_str,
      "` as a classmethod is not supported, please make it a plain method");

  if (!is_torch_function) {
    auto maybe_torch_dispatch_rule =
        maybe_get_registered_torch_dispatch_rule(torch_api_function, mode_obj);
    if (!maybe_torch_dispatch_rule.is_none()) {
      auto ret = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(
          maybe_torch_dispatch_rule.ptr(),
          mode_obj.ptr(),
          torch_api_function,
          py_types.ptr(),
          args,
          kwargs,
          NULL));
      if (ret.ptr() == nullptr) {
        throw python_error();
      }
      return std::make_tuple(ret, mode_obj);
    }
  }

  // Blegh. This accidentally works in PyObject_CallFunctionObjArgs below
  // because the nullptr terminates the argument list ick ick ick.
  py::object ret;
  if (kwargs == nullptr) {
    ret = py::reinterpret_steal<py::object>(PyObject_CallMethod(
        mode_obj.ptr(),
        torch_function_name_str,
        "OOO",
        torch_api_function,
        py_types.ptr(),
        args));
  } else {
    ret = py::reinterpret_steal<py::object>(PyObject_CallMethod(
        mode_obj.ptr(),
        torch_function_name_str,
        "OOOO",
        torch_api_function,
        py_types.ptr(),
        args,
        kwargs));
  }
  if (ret.ptr() == nullptr) {
    throw python_error();
  }
  return std::make_tuple(ret, mode_obj);
}
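
// A minimal sketch of a mode object this function accepts (LoggingMode is a
// hypothetical class; note that __torch_function__ is a plain method, as the
// TORCH_CHECK above requires):
// ```python
// class LoggingMode(torch.overrides.TorchFunctionMode):
//     def __torch_function__(self, func, types, args=(), kwargs=None):
//         print(func.__name__)
//         return func(*args, **(kwargs or {}))
// ```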

// See Note: [Overloaded args] for what they hold
auto handle_torch_function_no_python_arg_parser(
    at::ArrayRef<PyObject*> overloaded_args,
    PyObject* args,
    PyObject* kwargs,
    const char* func_name,
    PyObject* torch_api_function,
    const char* module_name,
    TorchFunctionName torch_function_name) -> PyObject* {
  const char* torch_function_name_str = nullptr;
  switch (torch_function_name) {
    case TorchFunctionName::TorchFunction:
      torch_function_name_str = "__torch_function__";
      break;
    case TorchFunctionName::TorchDispatch:
      torch_function_name_str = "__torch_dispatch__";
      break;
    default:
      TORCH_INTERNAL_ASSERT(0, static_cast<int>(torch_function_name));
  }
  // overloaded_args already all have unique types
  // nb: modes don't go in the overloaded types list, as they are not
  // necessarily types
  std::vector<py::object> overloaded_types;
  overloaded_types.reserve(overloaded_args.size());
  for (auto& arg : overloaded_args) {
    overloaded_types.push_back(
        py::reinterpret_borrow<py::object>(get_type_of_overloaded_arg(arg)));
  }
  py::tuple py_types = py::cast(overloaded_types);
  py::object ret;
  py::object mode_obj;

  // Step 1: Try to dispatch based on the mode stack, *ignoring* infra
  // torch_dispatch modes.
  const bool is_torch_function =
      torch_function_name == TorchFunctionName::TorchFunction;
  const auto is_mode_active = [&]() {
    return is_torch_function
        ? at::impl::torch_function_mode_enabled()
        // Check if any *user* torch_dispatch modes are active (not including
        // fake and proxy modes, which are special)
        : c10::impl::dispatch_mode_enabled();
  };
  // Note [__torch_dispatch__ dispatching order]
  // The high-level idea motivating the dispatching
  // order below is that: (1) modes get higher dispatch precedence over
  // subclasses (2) "user" modes/subclasses get higher dispatch precedence over
  // "infra" modes/subclasses.
  //
  // To give a complete example: let's say we are running torch.compile, with
  // the following "user" modes and subclasses:
  //   mode_stack: [ModeA]
  //   user_args: [MyWrapperSubclassB(torch.Tensor)]

  // During AOTAutograd tracing, we use some additional infra modes
  // and subclasses to perform tracing:
  //   FunctionalTensorMode, ProxyTorchDispatchMode, FakeTensorMode,
  //   FunctionalTensor, FakeTensor
  // The modified mode stack and tracing arguments will look like this:
  //   mode_stack (user modes): [ModeA]
  //   mode_stack (infra modes): [
  //     FunctionalTensorMode, ProxyTorchDispatchMode, FakeTensorMode
  //   ]
  //   tracing_args: [
  //     MyWrapperSubclassB(FunctionalTensor(_to_functional_tensor(FakeTensor)))
  //   ]

  // And the dispatching order that we want is as follows:
  // (1) ModeA.__torch_dispatch__ (user modes highest)
  // (2) MyWrapperSubclassB.__torch_dispatch__ (user subclasses next highest)
  // (3) FunctionalTensorMode.__torch_dispatch__ (infra modes next highest)
  // (4) ProxyTorchDispatchMode.__torch_dispatch__ (infra modes next highest)
  // (5) FakeTensorMode.__torch_dispatch__ (infra modes next highest)
  // (6) FakeTensor.__torch_dispatch__ (infra subclasses next highest)

  // Why do FunctionalTensor and FakeTensor even need to be special-cased
  // in the ordering?
  // In theory we could remove their __torch_dispatch__, but both of these
  // subclasses override sizes/strides metadata calls with __torch_dispatch__,
  // which would mean a mode would be **required** to access their metadata.

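  // Concretely, for the torch.compile example above, a call such as
  // ```python
  // with ModeA():
  //     out = some_op(my_wrapper_subclass_b_instance)
  // ```
  // is first offered to ModeA; if that returns NotImplemented, dispatch
  // falls through to MyWrapperSubclassB, and only then to the infra modes
  // and subclasses.
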
  if (is_mode_active()) {
    // Step 1: Try to dispatch on any user TorchDispatchModes (including infra
    // modes, which will always be at the bottom of the mode stack).
    auto ret_ = dispatch_on_mode(
        args,
        kwargs,
        py_types,
        torch_api_function,
        is_torch_function,
        torch_function_name_str);
    ret = std::get<0>(ret_);
    mode_obj = std::get<1>(ret_);
  }

  // Step 2: Try to dispatch based on any user subclasses,
  // ignoring any subclasses that have a _mode_key field
  // (corresponding to infra subclasses)
  // Note: user subclasses should always run *before* infra modes like
  // proxy/fake. This is handled by having proxy/fake modes return
  // NotImplemented when they see a user subclass that they don't understand.
  if (ret.ptr() == nullptr || ret.ptr() == Py_NotImplemented) {
    auto curr_ret = dispatch_on_subclass(
        args,
        kwargs,
        overloaded_args,
        py_types,
        torch_api_function,
        is_torch_function,
        torch_function_name_str);
    if (curr_ret.ptr() != nullptr) {
      ret = curr_ret;
    }
  }

  if (ret.ptr() == nullptr) {
    // if an exception occurred in a user's implementation of
    // __torch_function__, throw it
    throw python_error();
  } else if (ret.ptr() == Py_NotImplemented) {
    // all __torch_function__ implementations in overloaded_args
    // returned NotImplemented, so we raise a TypeError.
    std::stringstream ss;
    ss << "Multiple dispatch failed for '";
    if (module_name && func_name) {
      ss << module_name << "." << func_name;
    } else {
      py::handle fn = torch_api_function;
      ss << py::str(fn.attr("__module__")) << "."
         << py::str(fn.attr("__name__"));
    }
    ss << "'; all " << torch_function_name_str
       << " handlers returned NotImplemented:\n\n";
    if (mode_obj) {
      ss << " - mode object " << py::repr(mode_obj) << "\n";
    }
    for (auto& arg : overloaded_args) {
      ss << " - tensor subclass " << py::repr(get_type_of_overloaded_arg(arg))
         << "\n";
    }
    ss << "\nFor more information, try re-running with TORCH_LOGS=not_implemented";
    const std::string& tmp = ss.str();
    PyErr_SetString(PyExc_TypeError, tmp.c_str());
    throw python_error();
  }
  return ret.release().ptr();
}

auto handle_torch_function(
    PythonArgs& r,
    PyObject* self,
    PyObject* args,
    PyObject* kwargs,
    PyObject* torch_api,
    const char* module_name,
    const char* func_name_override) -> PyObject* {
  py::object torch_api_function = PyObject_FastGetAttrString(
      torch_api,
      (char*)(func_name_override ? func_name_override
                                 : r.get_func_name().c_str()));
  TORCH_INTERNAL_ASSERT(
      torch_api_function.ptr() != nullptr, "torch API function must exist");
  py::tuple args_ = combine_self_args(self, args);
  return handle_torch_function_no_python_arg_parser(
      r.overloaded_args,
      args_.ptr(),
      kwargs,
      r.get_func_name().c_str(),
      torch_api_function.ptr(),
      module_name);
}

auto handle_torch_function(
    PythonArgs& r,
    PyObject* args,
    PyObject* kwargs,
    PyObject* torch_api,
    const char* module_name,
    const char* func_name_override) -> PyObject* {
  return handle_torch_function(
      r, nullptr, args, kwargs, torch_api, module_name, func_name_override);
}

auto handle_torch_function_indexing(
    PyObject* self,
    PyObject* index,
    PyObject* val) -> PyObject* {
  const char* func_name = (val == nullptr) ? "__getitem__" : "__setitem__";
  py::object index_tup;
  if (PyTuple_Check(index)) {
    index_tup = py::reinterpret_borrow<py::object>(index);
  } else {
    index_tup = py::make_tuple(py::handle(index));
  }
  std::vector<PyObject*> overridable_args;
  is_tensor_and_append_overloaded(self, &overridable_args);
  auto size = PyTuple_GET_SIZE(index_tup.ptr());
  for (auto i : c10::irange(size)) {
    auto* obj = PyTuple_GetItem(index_tup.ptr(), i);
    is_tensor_and_append_overloaded(obj, &overridable_args);
  }
  if (val != nullptr) {
    is_tensor_and_append_overloaded(val, &overridable_args);
  }
  py::object func =
      PyObject_FastGetAttrString(THPVariableClass, (char*)func_name);
  py::object args = (val == nullptr)
      ? py::make_tuple(py::handle(self), py::handle(index))
      : py::make_tuple(py::handle(self), py::handle(index), py::handle(val));
  return handle_torch_function_no_python_arg_parser(
      overridable_args,
      args.ptr(),
      nullptr,
      func_name,
      func.ptr(),
      "torch.Tensor");
}

/*
 * obj has a __torch_function__ implementation and may either be a
 * subclass of Tensor or a Tensor-like duck type. We may need to
 * append this object to the overloaded_args vector, which tracks all
 * of the arguments with distinct __torch_function__ implementations
 * we've seen so far.
 *
 * If this is the first argument we've seen with __torch_function__
 * defined, we unconditionally add obj to the overloaded_args vector.
 *
 * If we've already seen arguments with __torch_function__ defined,
 * then we first need to check if obj is the same type as any of the
 * entries in overloaded_args. If so, we can ignore obj since we
 * already have an entry in overloaded_args with the same
 * __torch_function__ implementation.
 *
 * If it's a different type, we then need to check if it's a subclass
 * of one of the types we've already seen. If so, we need to insert an
 * entry in overloaded_args for this type with higher precedence than
 * the superclass.
 *
 * See torch._overrides._get_overloaded_args for the equivalent
 * function in the Python __torch_function__ implementation.
 *
 * The precedence-determining algorithm implemented in this function is
 * described in NEP-0018:
 * https://numpy.org/neps/nep-0018-array-function-protocol.html
 *
 * 'overloaded_args' is a raw pointer to a vector of pybind11 handles
 * that have distinct __torch_function__ implementations, in order of calling
 * precedence.
 *
 * 'obj' is an object to check for a __torch_function__ implementation
 *
 * If changing this file in a way that can affect the __torch_function__
 * overhead, please report the benchmarks in 'benchmarks/overrides_benchmark'.
 * See the instructions in the 'README.md' in that directory.
 */
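
// For example (illustrative): if B subclasses A and both define
// __torch_function__, then after scanning arguments (a: A, b: B) in that
// order, overloaded_args holds [b, a], so B's implementation runs before its
// superclass A's, mirroring the NEP-0018 precedence rules.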

static void append_overloaded_arg(
    std::vector<PyObject*>* overloaded_args,
    PyObject* obj,
    bool obj_is_type) {
  bool class_not_seen_yet = true;
  PyObject* obj_type = obj_is_type ? obj : (PyObject*)Py_TYPE(obj);
  for (auto& arg : *overloaded_args) {
    if (obj_type == get_type_of_overloaded_arg(arg)) {
      // obj is the same type as another parameter we've seen in a prior
      // iteration of the loop over parameters so we already have an entry
      // with the proper __torch_function__ implementation to call, so skip
      // this parameter
      class_not_seen_yet = false;
      break;
    }
  }
  if (class_not_seen_yet) {
    auto arg_index = overloaded_args->size();
    for (const auto j : c10::irange(arg_index)) {
      if (PyObject_IsSubclass(
              obj_type, get_type_of_overloaded_arg((*overloaded_args)[j]))) {
        // obj is a subclass of another object we've seen already so its
        // __torch_function__ should be called first, therefore we
        // insert it into overloaded_args before the superclass
        arg_index = j;
        break;
      }
    }
    // add object to overloaded_args. If it's a subclass of another class
    // we've already seen it will be inserted before the superclass,
    // otherwise it will be inserted at the end of the array
    overloaded_args->insert(
        overloaded_args->begin() + static_cast<long>(arg_index), obj);
  }
}

void append_overloaded_tensor(
    std::vector<PyObject*>* overloaded_args,
    PyObject* obj) {
  append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/ false);
}

void append_overloaded_type(
    std::vector<PyObject*>* overloaded_args,
    PyObject* obj) {
  append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/ true);
}

bool is_tensor_and_append_overloaded(
    PyObject* obj,
    std::vector<PyObject*>* overloaded_args) {
  if (THPVariable_CheckExact(obj)) {
    // torch.Tensor instances (not subclasses, except for Parameter)
    return true;
  }

  if (check_has_torch_function(obj, /*ignore_mode*/ true)) {
    // tensor subclasses and unrelated objects with __torch_function__
    append_overloaded_tensor(overloaded_args, obj);
    return true;
  } else if (THPVariable_Check(obj)) {
    // tensor subclasses without __torch_function__
    return true;
  }

  return false;
}

static bool is_scalar_list(PyObject* obj) {
  auto tuple = six::isTuple(obj);
  if (!(tuple || PyList_Check(obj))) {
    return false;
  }
  // NOLINTNEXTLINE(bugprone-branch-clone)
  const auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
  for (const auto idx : c10::irange(size)) {
    PyObject* iobj =
        tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
    if (!THPUtils_checkScalar(iobj)) {
      return false;
    }
  }
  return true;
}

bool is_tensor_list_and_append_overloaded(
    PyObject* obj,
    std::vector<PyObject*>* overloaded_args,
    size_t argnum,
    bool throw_error) {
  auto tuple = six::isTuple(obj);
  if (!(tuple || PyList_Check(obj))) {
    return false;
  }
  // NOLINTNEXTLINE(bugprone-branch-clone)
  const auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
  for (long idx = 0; idx < size; idx++) {
    PyObject* iobj =
        tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
    if (!is_tensor_and_append_overloaded(iobj, overloaded_args)) {
      if (throw_error) {
        TORCH_CHECK_TYPE(
            false,
            "expected Tensor as element ",
            idx,
            " in argument ",
            argnum,
            ", but got ",
            Py_TYPE(iobj)->tp_name);
      }
      return false;
    }
  }
  return true;
}

static bool is_float_or_complex_list(PyObject* obj) {
  auto tuple = six::isTuple(obj);
  if (!(tuple || PyList_Check(obj))) {
    return false;
  }

  // NOLINTNEXTLINE(bugprone-branch-clone)
  const auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
  if (size > 0) {
    PyObject* iobj = tuple ? PyTuple_GET_ITEM(obj, 0) : PyList_GET_ITEM(obj, 0);
    if (!THPUtils_checkDouble(iobj) && !PyComplex_Check(iobj)) {
      return false;
    }
  }

  return true;
}

static bool is_int_or_symint(PyObject* obj) {
  // THPUtils_checkIndex may call __index__ or __int__
  // which may have side effects if obj is a symint node
  // so we do `is_symint` check first
  // TODO: maybe we should be using checkLong here?
  if (torch::is_symint(py::handle(obj))) {
    return true;
  }

  // FakeTensor(..., size=()) is qualified for SymInt param,
  // but we can't go via __index__ (below) as we would normally
  // do for regular tensors, because __index__ first forces a
  // conversion into an int, which in general you cannot do
  // if you have an unbacked SymInt. So this fastpath ensures
  // that we still allow for fake tensors in this case, but
  // for regular tensors it's redundant with the test below.
  if (THPVariable_Check(obj)) {
    auto& var = THPVariable_Unpack(obj);
    if (TORCH_GUARD_SIZE_OBLIVIOUS(var.sym_numel().sym_eq(1)) &&
        at::isIntegralType(var.dtype().toScalarType(), /*include_bool*/ true)) {
      return true;
    }
  }

  if (THPUtils_checkIndex(obj)) {
    return true;
  }

  return false;
}

static bool is_int_or_symint_list(
    PyObject* obj,
    int broadcast_size,
    int64_t* failed_idx = nullptr) {
  if (PyTuple_Check(obj) || PyList_Check(obj)) {
    if (PySequence_Size(obj) == 0) {
      return true;
    }
    auto item = py::reinterpret_steal<py::object>(PySequence_GetItem(obj, 0));

    if (is_int_or_symint(item.ptr())) {
      return true;
    }

    // NOTE: JIT tracer allows arbitrary scalar tensors to act as ints
    // in an intlist argument. Even float or complex scalar tensors.
    bool r =
        (jit::tracer::isTracing() && THPVariable_Check(item.ptr()) &&
         THPVariable_Unpack(item.ptr()).sizes().empty());
    if (!r && failed_idx != nullptr) {
      *failed_idx = 0;
    }
    return r;
  }

  // if a size is specified (e.g. IntArrayRef[2]) we also allow passing a
  // single int
  return broadcast_size > 0 && is_int_or_symint(obj);
}
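
// For example, a parameter declared as IntArrayRef[2] accepts both forms:
// ```python
// F.avg_pool2d(x, kernel_size=2)       # single int, broadcast to (2, 2)
// F.avg_pool2d(x, kernel_size=(2, 2))  # explicit list/tuple
// ```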

// argnum is needed for raising the TypeError, it's used in the error message.
auto FunctionParameter::check(
    PyObject* obj,
    std::vector<PyObject*>& overloaded_args,
    int argnum,
    int64_t* failed_idx) -> bool {
  switch (type_) {
    case ParameterType::TENSOR: {
      if (is_tensor_and_append_overloaded(obj, &overloaded_args)) {
        return true;
      }
      if (allow_numbers_as_tensors) {
        return THPUtils_checkScalar(obj);
      }
      return false;
    }
    case ParameterType::SCALAR:
      if (THPUtils_checkScalar(obj)) {
        return true;
      }
      [[fallthrough]];
    case ParameterType::COMPLEX:
      if (PyComplex_Check(obj)) {
        return true;
      }
      [[fallthrough]];
    case ParameterType::DOUBLE: {
      if (THPUtils_checkDouble(obj)) {
        return true;
      }
      if (THPVariable_Check(obj)) {
        const auto& var = THPVariable_Unpack(obj);
        return !var.requires_grad() && var.dim() == 0;
      }
      if (torch::is_symfloat(py::handle(obj)) ||
          torch::is_symint(py::handle(obj))) {
        // This will induce a guard
        return true;
      }
      return false;
    }
    case ParameterType::INT64: {
      if (THPUtils_checkLong(obj)) {
        return true;
      }
      if (THPVariable_Check(obj)) {
        const auto& var = THPVariable_Unpack(obj);
        return at::isIntegralType(var.scalar_type(), /*includeBool=*/false) &&
            !var.requires_grad() && var.dim() == 0;
      }
      if (torch::is_symint(py::handle(obj))) {
        // This will induce a guard
        return true;
      }
      return false;
    }
    case ParameterType::DIMNAME:
      return THPUtils_checkDimname(obj);
    case ParameterType::DIMNAME_LIST: {
      if (THPUtils_checkDimnameList(obj)) {
        return true;
      }
      // if a size is specified (e.g. DimnameList[1]) we also allow passing a
      // single Dimname
      return size == 1 && THPUtils_checkDimname(obj);
    }
    case ParameterType::TENSOR_LIST: {
      return is_tensor_list_and_append_overloaded(
          obj, &overloaded_args, argnum, true /* throw_error */);
    }
    case ParameterType::FLOAT_LIST:
      return is_float_or_complex_list(obj);
    case ParameterType::GENERATOR:
      return THPGenerator_Check(obj);
    case ParameterType::BOOL:
      return PyBool_Check(obj);
    case ParameterType::STORAGE:
      return isStorage(obj);
    case ParameterType::PYOBJECT:
      return true;
    case ParameterType::SCALARTYPE:
      return THPDtype_Check(obj) || THPPythonScalarType_Check(obj);
    case ParameterType::LAYOUT:
      return THPLayout_Check(obj);
    case ParameterType::MEMORY_FORMAT:
      return THPMemoryFormat_Check(obj);
    case ParameterType::QSCHEME:
      return THPQScheme_Check(obj);
    case ParameterType::DEVICE:
      // Allow symint to be passed in as device, but we'll specialize and
      // guard in this case.
      return THPUtils_checkLong(obj) || THPUtils_checkString(obj) ||
          THPDevice_Check(obj) || torch::is_symint(py::handle(obj));
    case ParameterType::STREAM:
      return THPStream_Check(obj);
    case ParameterType::STRING:
      return THPUtils_checkString(obj);
    case ParameterType::SCALAR_LIST:
      return is_scalar_list(obj);
    case ParameterType::SYM_INT:
      return is_int_or_symint(obj);
    // Allow SymInt where int is expected; we'll guard in this case
    case ParameterType::INT_LIST:
    case ParameterType::SYM_INT_LIST:
      return is_int_or_symint_list(obj, size, failed_idx);
    case ParameterType::DISPATCH_KEY_SET:
      return py::isinstance<c10::DispatchKeySet>(py::handle(obj));
    default:
      throw std::runtime_error("unknown parameter type");
  }
}

// WARNING: these strings are parsed by invalid_arguments.cpp
std::string FunctionParameter::type_name() const {
  switch (type_) {
    case ParameterType::TENSOR:
      return "Tensor";
    case ParameterType::SCALAR:
      return "Number";
    case ParameterType::INT64:
    // NB: SymInt is intentionally not mentioned here, as conventional user
    // use will only know about ints
    case ParameterType::SYM_INT:
      return "int";
    case ParameterType::DOUBLE:
      return "float";
    case ParameterType::COMPLEX:
      return "complex";
    case ParameterType::TENSOR_LIST:
      return "tuple of Tensors";
    case ParameterType::INT_LIST:
      return "tuple of ints";
    case ParameterType::FLOAT_LIST:
      return "tuple of floats";
    case ParameterType::GENERATOR:
      return "torch.Generator";
    case ParameterType::BOOL:
      return "bool";
    case ParameterType::STORAGE:
      return "torch.Storage";
    case ParameterType::PYOBJECT:
      return "object";
    case ParameterType::SCALARTYPE:
      return "torch.dtype";
    case ParameterType::LAYOUT:
      return "torch.layout";
    case ParameterType::MEMORY_FORMAT:
      return "torch.memory_format";
    case ParameterType::QSCHEME:
      return "torch.qscheme";
    case ParameterType::DEVICE:
      return "torch.device";
    case ParameterType::STRING:
      return "str";
    case ParameterType::DIMNAME:
      return "name";
    case ParameterType::DIMNAME_LIST:
      return "tuple of names";
    case ParameterType::SCALAR_LIST:
      return "tuple of Scalars";
    case ParameterType::SYM_INT_LIST:
      return "tuple of ints";
    case ParameterType::DISPATCH_KEY_SET:
      return "DispatchKeySet";
    default:
      throw std::runtime_error("unknown parameter type");
  }
}

static inline std::optional<int64_t> parse_as_integer(const std::string& s) {
  if (s.empty())
    return std::nullopt;
  char* str_end = nullptr;
  long ans = strtol(s.c_str(), &str_end, 0);
  // *str_end == 0 if the entire string was parsed as an integer.
  return (*str_end == 0) ? std::optional<int64_t>(ans) : std::nullopt;
}

/*
Parse the default value of an IntArrayRef declared in native_functions.yaml.

There are two kinds of default values:
1. IntArrayRef[2] x=1 (where size=2, value={1,1})
2. IntArrayRef x={1,2,3} (where size=3, value={1,2,3}; note that there cannot
   be a space after the comma since native_parse.py uses ', ' to split args)
*/
static inline std::vector<int64_t> parse_intlist_args(
    const std::string& s,
    int64_t size) {
  size_t n = s.size();

  if (s.empty())
    return std::vector<int64_t>();

  // case 1. s is an int (e.g., s=2)
  if (s[0] != '{') {
    TORCH_CHECK(size > 0, "Incorrect size of IntArrayRef: ", size);
    return std::vector<int64_t>(size, std::stol(s));
  }

  // case 2. s is a list of dims (e.g., s={1,2})

  // since we already checked for the left brace '{' above, here we only check
  // for the right brace '}'
  TORCH_CHECK(
      s[n - 1] == '}',
      "Default value of IntArrayRef is missing right brace '}', found ",
      s[n - 1]);

  auto args = std::vector<int64_t>();
  std::istringstream ss(s.substr(1, s.length() - 2)); // exclude '{' and '}'
  std::string tok;

  while (std::getline(ss, tok, ',')) {
    args.emplace_back(std::stol(tok));
  }
  return args;
}

// Parse a string literal to remove quotes and escape sequences
static std::string parse_string_literal(c10::string_view str) {
  TORCH_CHECK(str.length() >= 2, "String defaults must be quoted");

  if (str.front() == '"') {
    TORCH_CHECK(
        str.back() == '"', "Mismatched quotes in string default: ", str);
  } else {
    TORCH_CHECK(
        str.front() == '\'' && str.back() == '\'',
        "Invalid quotes in string default: ",
        str)
  }

  std::string parsed;
  parsed.reserve(str.size());
  for (size_t i = 1; i < str.size() - 1;) {
    if (str[i] != '\\') {
      parsed.push_back(str[i]);
      ++i;
      continue;
    }

    // Handle escape sequences
    TORCH_CHECK(
        i < str.size() - 2, "String ends with escaped final quote: ", str)
    char c = str[i + 1];
    switch (c) {
      case '\\':
      case '\'':
      case '\"':
        break;
      case 'a':
        c = '\a';
        break;
      case 'b':
        c = '\b';
        break;
      case 'f':
        c = '\f';
        break;
      case 'n':
        c = '\n';
        break;
      case 'v':
        c = '\v';
        break;
      case 't':
        c = '\t';
        break;
      default:
        TORCH_CHECK(
            false,
            "Unsupported escape sequence in string default: \\",
            str[i + 1]);
    }
    parsed.push_back(c);
    i += 2;
  }
  return parsed;
}
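
// e.g. parse_string_literal(R"("a\tb")") yields the three-character string
// {'a', '\t', 'b'}; escapes outside the switch above (such as "\x41") are
// rejected with an error.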

void FunctionParameter::set_default_str(const std::string& str) {
  if (str == "None") {
    allow_none = true;
  }
  if (type_ == ParameterType::TENSOR ||
      type_ == ParameterType::DISPATCH_KEY_SET) {
    if (str != "None") {
      throw std::runtime_error(
          "default value for Tensor must be none, got: " + str);
    }
  } else if (type_ == ParameterType::INT64 || type_ == ParameterType::SYM_INT) {
    default_int = atol(str.c_str());
  } else if (type_ == ParameterType::BOOL) {
    default_bool = (str == "True" || str == "true");
  } else if (type_ == ParameterType::DOUBLE) {
    default_double = atof(str.c_str());
  } else if (type_ == ParameterType::COMPLEX) {
    default_complex[0] = atof(str.c_str()); // TODO: parse "x + xj"?
    default_complex[1] = 0;
  } else if (type_ == ParameterType::SCALAR) {
    if (str != "None") {
      // we sometimes rely on integer-vs-float values, e.g. with arange.
      const auto as_integer = parse_as_integer(str);
      default_scalar = as_integer.has_value() ? at::Scalar(as_integer.value())
                                              : at::Scalar(atof(str.c_str()));
    }
  } else if (
      type_ == ParameterType::INT_LIST ||
      type_ == ParameterType::SYM_INT_LIST) {
    if (str != "None") {
      default_intlist = parse_intlist_args(str, size);
    }
  } else if (type_ == ParameterType::FLOAT_LIST) {
    if (str != "None") {
      throw std::runtime_error("Defaults not supported for float[]");
    }
  } else if (type_ == ParameterType::SCALARTYPE) {
    if (str == "None") {
      default_scalartype = at::ScalarType::Undefined;
    } else if (str == "torch.int64") {
      default_scalartype = at::ScalarType::Long;
    } else {
      throw std::runtime_error("invalid default value for ScalarType: " + str);
    }
  } else if (type_ == ParameterType::LAYOUT) {
    if (str == "None") {
      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(allow_none);
    } else if (str == "torch.strided") {
      default_layout = at::Layout::Strided;
    } else if (str == "torch.sparse_coo") {
      default_layout = at::Layout::Sparse;
    } else {
      throw std::runtime_error("invalid default value for layout: " + str);
    }
  } else if (type_ == ParameterType::DEVICE) {
    if (str != "None") {
      throw std::runtime_error("invalid device: " + str);
    }
  } else if (type_ == ParameterType::STREAM) {
    if (str != "None") {
      throw std::runtime_error("invalid stream: " + str);
    }
  } else if (type_ == ParameterType::STRING) {
    if (str != "None") {
      default_string = parse_string_literal(str);
    }
  }
  // These types weren't handled here before. Adding a default error
  // led to a lot of test failures so adding this skip for now.
  // We should correctly handle these though because it might be causing
  // silent failures.
  else if (type_ == ParameterType::TENSOR_LIST) { // NOLINT
    // throw std::runtime_error("Invalid Tensor List");
  } else if (type_ == ParameterType::GENERATOR) { // NOLINT
    // throw std::runtime_error("ParameterType::GENERATOR");
  } else if (type_ == ParameterType::PYOBJECT) { // NOLINT
    // throw std::runtime_error("ParameterType::PYOBJECT");
  } else if (type_ == ParameterType::MEMORY_FORMAT) { // NOLINT
    // throw std::runtime_error("ParameterType::MEMORY_FORMAT");
  } else if (type_ == ParameterType::DIMNAME) { // NOLINT
    // throw std::runtime_error("ParameterType::DIMNAME");
  } else if (type_ == ParameterType::DIMNAME_LIST) { // NOLINT
    // throw std::runtime_error("ParameterType::DIMNAME_LIST");
  } else if (type_ == ParameterType::SCALAR_LIST) { // NOLINT
    // throw std::runtime_error("ParameterType::SCALAR_LIST");
  } else if (type_ == ParameterType::STORAGE) { // NOLINT
    // throw std::runtime_error("ParameterType::STORAGE");
  } else if (type_ == ParameterType::QSCHEME) { // NOLINT
    // throw std::runtime_error("ParameterType::QSCHEME");
  } else {
    throw std::runtime_error("unknown parameter type");
  }
  default_value = str;
}

// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
FunctionSignature::FunctionSignature(const std::string& fmt, int index)
    : min_args(0),
      max_args(0),
      max_pos_args(0),
      index(index),
      hidden(false),
      deprecated(false) {
  auto open_paren = fmt.find('(');
  if (open_paren == std::string::npos) {
    throw std::runtime_error("missing opening parenthesis: " + fmt);
  }
  name = fmt.substr(0, open_paren);

  bool allow_numbers_as_tensors = should_allow_numbers_as_tensors(name);

  auto last_offset = open_paren + 1;
  bool keyword_only = false;
  bool done = false;
  while (!done) {
    auto offset = fmt.find(", ", last_offset);
    auto next_offset = offset + 2;
    if (offset == std::string::npos) {
      offset = fmt.find(')', last_offset);
      done = true;
      next_offset = offset + 1;
      // this 'if' happens for an empty parameter list, i.e. fn().
      if (offset == last_offset) {
        last_offset = next_offset;
        break;
      }
    }
    if (offset == std::string::npos) {
      throw std::runtime_error("missing closing parenthesis: " + fmt);
    }
    if (offset == last_offset) {
      throw std::runtime_error("malformed signature: " + fmt);
    }

    auto param_str = fmt.substr(last_offset, offset - last_offset);
    last_offset = next_offset;
    if (param_str == "*") {
      keyword_only = true;
    } else {
      params.emplace_back(param_str, keyword_only);
      params.back().allow_numbers_as_tensors = allow_numbers_as_tensors;
    }
  }

  if (fmt.substr(last_offset) == "|deprecated") {
    hidden = true;
    // TODO: raise warning when parsing deprecated signatures
    deprecated = true;
  } else if (fmt.substr(last_offset) == "|hidden") {
    hidden = true;
  }

  max_args = params.size();

  // count the number of non-optional args
  for (auto& param : params) {
    if (!param.optional) {
      min_args++;
    }
    if (!param.keyword_only) {
      max_pos_args++;
    }
  }
}
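
// An illustrative signature string this constructor parses:
//   "sum(Tensor input, IntArrayRef[1] dim, bool keepdim=False, *, ScalarType? dtype=None)"
// yields min_args=2 (input and dim are non-optional), max_pos_args=3
// (dtype is keyword-only), and max_args=4.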

std::string FunctionSignature::toString() const {
  // TODO: consider printing more proper schema strings with defaults,
  // optionals, etc.
  std::ostringstream ss;
  bool keyword_already = false;
  ss << "(";
  int i = 0;
  for (auto& param : params) {
    if (i != 0) {
      ss << ", ";
    }
    if (param.keyword_only && !keyword_already) {
      ss << "*, ";
      keyword_already = true;
    }
    ss << param.type_name() << " " << param.name;
    if (param.optional) {
      ss << " = " << param.default_value;
    }
    i++;
  }
  ss << ")";
  return ss.str();
}

[[noreturn]] static void extra_args(
    const FunctionSignature& signature,
    Py_ssize_t nargs) {
  const auto max_pos_args = signature.max_pos_args;
  const auto min_args = signature.min_args;
  const long nargs_ = nargs;
  if (min_args != max_pos_args) {
    throw TypeError(
        "%s() takes from %zu to %zu positional arguments but %ld were given",
        signature.name.c_str(),
        min_args,
        max_pos_args,
        nargs_);
  }
  throw TypeError(
      "%s() takes %zu positional argument%s but %ld %s given",
      signature.name.c_str(),
      max_pos_args,
      max_pos_args == 1 ? "" : "s",
      nargs_,
      nargs == 1 ? "was" : "were");
}

[[noreturn]] static void missing_args(
    const FunctionSignature& signature,
    int idx) {
  int num_missing = 0;
  std::stringstream ss;

  auto& params = signature.params;
  for (auto it = params.begin() + idx; it != params.end(); ++it) {
    if (!it->optional) {
      if (num_missing > 0) {
        ss << ", ";
      }
      ss << '"' << it->name << '"';
      num_missing++;
    }
  }

  throw TypeError(
      "%s() missing %d required positional argument%s: %s",
      signature.name.c_str(),
      num_missing,
      num_missing == 1 ? "" : "s", // pluralize only when more than one is missing
      ss.str().c_str());
}

static Py_ssize_t find_param(FunctionSignature& signature, PyObject* name) {
  Py_ssize_t i = 0;
  for (auto& param : signature.params) {
    int cmp = PyObject_RichCompareBool(name, param.python_name, Py_EQ);
    if (cmp < 0) {
      throw python_error();
    } else if (cmp) {
      return i;
    }
    i++;
  }
  return -1;
}

[[noreturn]] static void extra_kwargs(
    FunctionSignature& signature,
    PyObject* kwargs,
    Py_ssize_t num_pos_args) {
  PyObject* key = nullptr;
  PyObject* value = nullptr;
  Py_ssize_t pos = 0;

  while (PyDict_Next(kwargs, &pos, &key, &value)) {
    if (!THPUtils_checkString(key)) {
      throw TypeError("keywords must be strings");
    }

    auto param_idx = find_param(signature, key);
    if (param_idx < 0) {
      throw TypeError(
          "%s() got an unexpected keyword argument '%s'",
          signature.name.c_str(),
          THPUtils_unpackString(key).c_str());
    }

    if (param_idx < num_pos_args) {
      throw TypeError(
          "%s() got multiple values for argument '%s'",
          signature.name.c_str(),
          THPUtils_unpackString(key).c_str());
    }
  }

  // this should never be hit
  throw TypeError("invalid keyword arguments");
}

bool FunctionSignature::parse(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs,
    PyObject* dst[], // NOLINT
    std::vector<PyObject*>& overloaded_args,
    bool raise_exception) {
  Py_ssize_t nargs = args ? PyTuple_GET_SIZE(args) : 0;
  auto remaining_kwargs = kwargs ? PyDict_Size(kwargs) : 0;
  size_t arg_pos = 0;
  bool allow_varargs_intlist = false;

  // if there is a single positional IntArrayRef argument, i.e. expand(..),
  // view(...), allow a var-args style IntArrayRef, so expand(5,3) behaves as
  // expand((5,3))
  if (max_pos_args == 1 &&
      (params[0].type_ == ParameterType::INT_LIST ||
       params[0].type_ == ParameterType::SYM_INT_LIST)) {
    allow_varargs_intlist = true;
  }

  if (static_cast<size_t>(nargs) > max_pos_args && !allow_varargs_intlist) {
    if (raise_exception) {
      // foo() takes 2 positional arguments but 3 were given
      extra_args(*this, nargs);
    }
    return false;
  }

  int i = 0;
  if (self != nullptr && check_has_torch_function(self, /*ignore_mode*/ true)) {
    append_overloaded_tensor(&overloaded_args, self);
  }
  for (auto& param : params) {
    PyObject* obj = nullptr;
    bool is_kwd = false;
    if (arg_pos < static_cast<size_t>(nargs)) {
      // extra positional args given after single positional IntArrayRef arg
      if (param.keyword_only) {
        if (raise_exception) {
          extra_args(*this, nargs);
        }
        return false;
      }
      obj = PyTuple_GET_ITEM(args, arg_pos);
    } else if (kwargs) {
      obj = PyDict_GetItem(kwargs, param.python_name);
      for (PyObject* numpy_name : param.numpy_python_names) {
        if (obj) {
          break;
        }
        obj = PyDict_GetItem(kwargs, numpy_name);
      }
      is_kwd = true;
    }

    int64_t failed_idx = -1;
    bool varargs_eligible = allow_varargs_intlist && arg_pos == 0 && !is_kwd;
    if ((!obj && param.optional) || (obj == Py_None && param.allow_none)) {
      dst[i++] = nullptr;
    } else if (!obj) {
      if (raise_exception) {
        // foo() missing 1 required positional argument: "b"
        missing_args(*this, i);
      }
      return false;
    } else if (param.check(obj, overloaded_args, i, &failed_idx)) {
      dst[i++] = obj;
      // XXX: the Variable check is necessary because sizes become tensors when
      // tracer is enabled. This behavior easily leads to ambiguities, and we
      // should avoid having complex signatures that make use of it...
    } else if (
        varargs_eligible &&
        (is_int_or_symint_list(args, param.size, &failed_idx))) {
      // take all positional arguments as this parameter
      // e.g. permute(1, 2, 3) -> permute((1, 2, 3))
      dst[i++] = args;
      arg_pos = nargs;
      continue;
    } else if (raise_exception) {
      if (is_kwd) {
        // foo(): argument 'other' must be str, not int
        throw TypeError(
            "%s(): argument '%s' must be %s, not %s",
            name.c_str(),
            param.name.c_str(),
            param.type_name().c_str(),
            Py_TYPE(obj)->tp_name);
      } else {
        // foo(): argument 'other' (position 2) must be str, not int
        if (failed_idx != -1) {
          if (!(PyTuple_Check(obj) || PyList_Check(obj))) {
            TORCH_INTERNAL_ASSERT(varargs_eligible);
            obj = args;
          }
          TORCH_INTERNAL_ASSERT(failed_idx < PySequence_Size(obj));
          throw TypeError(
              "%s(): argument '%s' (position %ld) must be %s, but found element of type %s at pos %ld",
              name.c_str(),
              param.name.c_str(),
              static_cast<long>(arg_pos + 1),
              param.type_name().c_str(),
              Py_TYPE(py::reinterpret_steal<py::object>(
                          PySequence_GetItem(obj, failed_idx))
                          .ptr())
                  ->tp_name,
              static_cast<long>(failed_idx));
        }
        throw TypeError(
            "%s(): argument '%s' (position %ld) must be %s, not %s",
            name.c_str(),
            param.name.c_str(),
            static_cast<long>(arg_pos + 1),
            param.type_name().c_str(),
            Py_TYPE(obj)->tp_name);
      }
    } else {
      return false;
    }

    if (!is_kwd) {
      arg_pos++;
    } else if (obj) {
      remaining_kwargs--;
    }
  }

  if (remaining_kwargs > 0) {
    if (raise_exception) {
      // foo() got an unexpected keyword argument "b"
      extra_kwargs(*this, kwargs, nargs);
    }
    return false;
  }
  return true;
}

PythonArgParser::PythonArgParser(
    const std::vector<std::string>& fmts,
    bool traceable)
    : max_args(0), traceable(traceable) {
  int index = 0;
  for (auto& fmt : fmts) {
    signatures_.emplace_back(fmt, index);
    ++index;
  }
  for (auto& signature : signatures_) {
    if (signature.max_args > max_args) {
      max_args = signature.max_args;
    }
  }
  if (!signatures_.empty()) {
    function_name = signatures_[0].name;
  }

  // Check deprecated signatures last
  std::stable_partition(
      signatures_.begin(), signatures_.end(), [](const FunctionSignature& sig) {
        return !sig.deprecated;
      });
}
1624
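// Emits a one-time warning (via TORCH_WARN_ONCE) when a deprecated overload
// is selected, listing the non-hidden signatures as suggested replacements.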
void PythonArgParser::check_deprecated(const FunctionSignature& signature) {
  if (signature.deprecated) {
    auto msg = c10::str(
        "This overload of ",
        signature.name,
        " is deprecated:\n\t",
        signature.name,
        signature.toString());
    auto signatures = get_signatures();
    if (!signatures.empty()) {
      msg += "\nConsider using one of the following signatures instead:";
      for (const auto& sig : signatures) {
        msg += "\n\t";
        msg += signature.name;
        msg += sig;
      }
    }
    TORCH_WARN_ONCE(msg);
  }
}

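// Entry point used by the bindings. With a single registered signature,
// parsing runs with exceptions enabled so a mismatch raises a precise
// TypeError; with several signatures, each is tried silently in order and
// print_error (which always throws) handles the case where none match.
//
// Hypothetical call site (raw_parse is normally reached through the templated
// parse() wrapper declared in the header; names here are illustrative):
// ```cpp
// PyObject* parsed[2]; // must hold at least max_args entries
// PythonArgs r = parser.raw_parse(self, args, kwargs, parsed);
// ```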
PythonArgs PythonArgParser::raw_parse(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs,
    PyObject* parsed_args[]) { // NOLINT
  if (signatures_.size() == 1) {
    auto& signature = signatures_[0];
    std::vector<PyObject*> overloaded_args;
    signature.parse(self, args, kwargs, parsed_args, overloaded_args, true);
    check_deprecated(signature);
    return PythonArgs(
        traceable, signature, parsed_args, std::move(overloaded_args));
  }

  for (auto& signature : signatures_) {
    std::vector<PyObject*> overloaded_args;
    if (signature.parse(
            self, args, kwargs, parsed_args, overloaded_args, false)) {
      check_deprecated(signature);
      return PythonArgs(
          traceable, signature, parsed_args, std::move(overloaded_args));
    }
  }

  print_error(self, args, kwargs, parsed_args);
}

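// Builds and throws the TypeError for a failed parse. If exactly one
// non-hidden signature could plausibly accept this many arguments, it is
// re-parsed with exceptions enabled to surface a specific message; otherwise
// a generic message listing every visible overload is thrown. Either way,
// this function never returns normally.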
void PythonArgParser::print_error(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs,
    PyObject* parsed_args[]) { // NOLINT
  size_t num_args =
      (args ? PyTuple_GET_SIZE(args) : 0) + (kwargs ? PyDict_Size(kwargs) : 0);
  std::vector<unsigned> plausible_idxs;
  unsigned i = 0;
  for (auto& signature : signatures_) {
    if (num_args >= signature.min_args && num_args <= signature.max_args &&
        !signature.hidden) {
      plausible_idxs.push_back(i);
    }
    i++;
  }

  if (plausible_idxs.size() == 1) {
    auto& signature = signatures_[plausible_idxs[0]];
    std::vector<PyObject*> overloaded_args;
    signature.parse(self, args, kwargs, parsed_args, overloaded_args, true);
  }

  auto options = get_signatures();
  auto msg =
      torch::format_invalid_args(args, kwargs, function_name + "()", options);
  throw TypeError("%s", msg.c_str());
}

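// Returns the textual form of every non-hidden signature, in the order they
// are tried; used for deprecation notices and error messages.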
std::vector<std::string> PythonArgParser::get_signatures() const {
  std::vector<std::string> options;
  for (auto& signature : signatures_) {
    if (!signature.hidden) {
      options.push_back(signature.toString());
    }
  }
  return options;
}

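// Slow path for binding a Python number where a Tensor is expected: the value
// is wrapped in a 0-dim tensor and flagged as a "wrapped number" so that type
// promotion treats it like a Python scalar. Symbolic values (SymInt/SymFloat/
// SymBool) do not survive the Scalar-to-Tensor conversion, so the original
// Python object is stashed out of band on the resulting tensor instead.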
at::Tensor PythonArgs::tensor_slow(int i) {
  PyObject* obj = args[i];
  if (!obj) {
    return at::Tensor();
  }
  if (THPVariable_Check(obj)) {
    return THPVariable_Unpack(obj);
  }

  bool save_symint = false;
  at::Scalar scalar;
  if (PyBool_Check(obj)) {
    scalar = at::Scalar(THPUtils_unpackBool(obj));
  } else if (THPUtils_checkLong(obj)) {
    int overflow = -1;
    long long value = PyLong_AsLongLongAndOverflow(obj, &overflow);
    if (value == -1 && PyErr_Occurred()) {
      throw python_error();
    }
    if (overflow != 0) {
      // try unsigned
      unsigned long long value = PyLong_AsUnsignedLongLong(obj);
      if (value == static_cast<unsigned long long>(-1) && PyErr_Occurred()) {
        throw python_error();
      }
      scalar = at::Scalar(static_cast<uint64_t>(value));
    } else {
      scalar = at::Scalar(static_cast<int64_t>(value));
    }
  } else if (PyComplex_Check(obj)) {
    scalar = at::Scalar(THPUtils_unpackComplexDouble(obj));
  } else if (THPUtils_checkDouble(obj)) {
    scalar = at::Scalar(THPUtils_unpackDouble(obj));
    // NB: we DO NOT put symbolic ints/floats into the Scalar itself,
    // because although Scalar supports SymInt/SymFloat, the subsequent
    // conversion to Tensor does not. Instead, do it out of band.
  } else if (torch::is_symint(py::handle(obj))) {
    save_symint = true;
    // This scalar value doesn't matter; it should never actually be read
    // out. Make it a big, weird-looking number to help people figure out
    // if there's a problem.
    scalar = at::Scalar(7777777);
  } else if (torch::is_symfloat(py::handle(obj))) {
    save_symint = true;
    scalar = at::Scalar(std::numeric_limits<double>::quiet_NaN());
  } else if (torch::is_symbool(py::handle(obj))) {
    save_symint = true;
    scalar = at::Scalar(true);
  } else {
    // NB: Are you here because you passed None to a Variable method,
    // and you expected an undefined tensor to be returned? Don't add
    // a test for Py_None here; instead, you need to mark the argument
    // as *allowing none*; you can do this by writing 'Tensor?' instead
    // of 'Tensor' in the ATen metadata.
    throw TypeError(
        "expected Tensor as argument %d, but got %s", i, Py_TYPE(obj)->tp_name);
  }
  at::AutoDispatchBelowADInplaceOrView guard; // TODO: remove
  at::tracer::impl::NoTracerDispatchMode tracer_guard;

  at::Tensor tensor = scalar_to_tensor(scalar);
  tensor.unsafeGetTensorImpl()->set_wrapped_number(true);

  if (save_symint) {
    auto py_tensor = py::cast(tensor);
    if (PyObject_SetAttrString(py_tensor.ptr(), "_wrapped_number", obj) < 0) {
      throw python_error();
    }
  }

  return tensor;
}

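// Slow path for binding a Scalar argument. Under the tracer, a Tensor passed
// where a number is expected is stashed with the tracer first, so the traced
// graph can treat it as a dynamic value rather than a baked-in constant.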
at::Scalar PythonArgs::scalar_slow(int i) {
  if (traceable && jit::tracer::isTracing() && THPVariable_Check(args[i])) {
    auto& var = THPVariable_Unpack(args[i]);
    jit::tracer::ArgumentStash::stashValue(
        signature.params[i].name, idx, var, c10::NumberType::get());
  }

  return scalar_slow(args[i]);
}

at::Scalar PythonArgs::scalar_slow(PyObject* arg) {
  // Zero-dim tensors are converted to Scalars as-is. Note this doesn't
  // currently handle most NumPy scalar types except np.float64.
  if (THPVariable_Check(arg)) {
    return THPVariable_Unpack(arg).item();
  }

  if (THPUtils_checkLong(arg)) {
    int overflow = -1;
    long long value = PyLong_AsLongLongAndOverflow(arg, &overflow);
    if (value == -1 && PyErr_Occurred()) {
      throw python_error();
    }
    if (overflow != 0) {
      // try unsigned
      unsigned long long value = PyLong_AsUnsignedLongLong(arg);
      if (value == static_cast<unsigned long long>(-1) && PyErr_Occurred()) {
        throw python_error();
      }
      return at::Scalar(static_cast<uint64_t>(value));
    } else {
      return at::Scalar(static_cast<int64_t>(value));
    }
  }

  if (PyBool_Check(arg)) {
    return at::Scalar(THPUtils_unpackBool(arg));
  }

  if (PyComplex_Check(arg)) {
    return at::Scalar(THPUtils_unpackComplexDouble(arg));
  }

  if (torch::is_symint(arg)) {
    return at::Scalar(py::cast<c10::SymInt>(arg));
  }

  if (torch::is_symfloat(arg)) {
    return at::Scalar(py::cast<c10::SymFloat>(arg));
  }

  if (torch::is_symbool(arg)) {
    // Windows build fails with C2440: '<function-style-cast>'
    // when using at::Scalar(py::cast<c10::SymBool>(arg))
    auto sym_bool = py::handle(arg).cast<c10::SymBool>();
    return at::Scalar(sym_bool);
  }

  return at::Scalar(THPUtils_unpackDouble(arg));
}

} // namespace torch