xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/python_arg_parser.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/utils/python_arg_parser.h>
2 
3 #include <torch/csrc/Exceptions.h>
4 #include <torch/csrc/Layout.h>
5 #include <torch/csrc/MemoryFormat.h>
6 #include <torch/csrc/autograd/python_variable.h>
7 #include <torch/csrc/utils/invalid_arguments.h>
8 #include <torch/csrc/utils/python_strings.h>
9 #include <torch/csrc/utils/python_torch_function_mode.h>
10 #include <torch/csrc/utils/torch_dispatch_mode.h>
11 
12 #include <ATen/ATen.h>
13 #include <ATen/PythonTorchFunctionTLS.h>
14 #include <ATen/TracerMode.h>
15 #include <c10/util/irange.h>
16 
17 #include <sstream>
18 #include <stdexcept>
19 #include <string>
20 #include <unordered_map>
21 #include <vector>
22 
23 namespace torch {
24 
// Maps the type token of a FunctionParameter format string (the text before
// the first space, after any trailing '?' and "[N]" suffix have been stripped)
// to the ParameterType the parser uses for that parameter. Several spellings
// can share one ParameterType, e.g. "std::string" and "c10::string_view" both
// map to STRING, and "DeviceIndex" is parsed as a plain INT64.
static std::unordered_map<std::string, ParameterType> type_map = {
    {"Tensor", ParameterType::TENSOR},
    {"Scalar", ParameterType::SCALAR},
    {"int64_t", ParameterType::INT64},
    {"SymInt", ParameterType::SYM_INT},
    {"double", ParameterType::DOUBLE},
    {"complex", ParameterType::COMPLEX},
    {"TensorList", ParameterType::TENSOR_LIST},
    {"c10::List<::std::optional<Tensor>>", ParameterType::TENSOR_LIST},
    {"IntArrayRef", ParameterType::INT_LIST},
    {"SymIntArrayRef", ParameterType::SYM_INT_LIST},
    {"ArrayRef<double>", ParameterType::FLOAT_LIST},
    {"Generator", ParameterType::GENERATOR},
    {"bool", ParameterType::BOOL},
    {"Storage", ParameterType::STORAGE},
    {"PyObject*", ParameterType::PYOBJECT},
    {"ScalarType", ParameterType::SCALARTYPE},
    {"Layout", ParameterType::LAYOUT},
    {"MemoryFormat", ParameterType::MEMORY_FORMAT},
    {"QScheme", ParameterType::QSCHEME},
    {"Device", ParameterType::DEVICE},
    {"DeviceIndex", ParameterType::INT64},
    {"Stream", ParameterType::STREAM},
    {"std::string", ParameterType::STRING},
    {"c10::string_view", ParameterType::STRING},
    {"Dimname", ParameterType::DIMNAME},
    {"DimnameList", ParameterType::DIMNAME_LIST},
    {"ScalarList", ParameterType::SCALAR_LIST},
    {"DispatchKeySet", ParameterType::DISPATCH_KEY_SET},
};
55 
// Default arg name translations for compatibility with NumPy.
//
// Example:
// ```python
// t = torch.randn(10,10)
// torch.sum(a=t, axis=0, keepdim=True)
// ```
//
// A vector is necessary, because we might need to try multiple values.
// In particular, NumPy sometimes uses "x" and sometimes "a" for the main input
// tensor. Rather than annotate each function separately with whether it should
// take "x" or "a", just try both.
//
// Keys are canonical PyTorch parameter names. When FunctionParameter parses a
// parameter whose name is a key, every alias in the value is interned and
// stored in that parameter's numpy_python_names.
//
// TODO: Allow individual functions to specify non-default translations:
// For example, `torch.pow` should translate "exponent" to "x2".
static const std::unordered_map<std::string, std::vector<std::string>>
    numpy_compatibility_arg_names = {
        {"dim", {"axis"}},
        {"keepdim", {"keepdims"}},
        {"input", {"x", "a", "x1"}},
        {"other", {"x2"}},
};
78 
// TODO: remove this. This is a temporary list of functions that allow Python
// numbers to bind to Tensors. Some binary ops have separate Tensor and Scalar
// overloads and binding to the Tensor overload with a number of a different
// type will trigger a type error.
//
// If you modify this, you will need to adjust the blocklist in
// tools/pyi/gen_pyi.py (and add hardcoded signatures for these
// functions.)
//
// Returns true iff `name` is one of the allow-listed function names below.
bool should_allow_numbers_as_tensors(const std::string& name) {
  // Built once on first use and never modified afterwards. Declaring it const
  // documents that and guarantees that concurrent callers only ever read it.
  static const std::unordered_set<std::string> allowed = {
      "add",          "add_",          "add_out",
      "div",          "div_",          "div_out",
      "divide",       "divide_",       "divide_out", // alias of div
      "mul",          "mul_",          "mul_out",
      "multiply",     "multiply_",     "multiply_out", // alias of mul
      "sub",          "sub_",          "sub_out",
      "subtract",     "subtract_",     "subtract_out", // alias of sub
      "true_divide",  "true_divide_",  "true_divide_out",
      "to",           "_to_copy",      "copy_",
      "floor_divide", "floor_divide_", "floor_divide_out",
      "_conj"}; // _conj needed because mul.Tensor backward calls it
  return allowed.find(name) != allowed.end();
}
102 
103 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
FunctionParameter(const std::string & fmt,bool keyword_only)104 FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
105     : optional(false),
106       allow_none(false),
107       keyword_only(keyword_only),
108       size(0),
109       default_scalar(0) {
110   auto space = fmt.find(' ');
111   if (space == std::string::npos) {
112     throw std::runtime_error("FunctionParameter(): missing type: " + fmt);
113   }
114 
115   auto type_str = fmt.substr(0, space);
116 
117   auto question = type_str.find('?');
118   if (question != std::string::npos) {
119     allow_none = true;
120     type_str = type_str.substr(0, question);
121   }
122 
123   // Parse and remove brackets from type_str
124   auto bracket = type_str.find('[');
125   if (bracket != std::string::npos) {
126     auto size_str =
127         type_str.substr(bracket + 1, type_str.length() - bracket - 2);
128     size = atoi(size_str.c_str());
129     type_str = type_str.substr(0, bracket);
130   }
131 
132   auto name_str = fmt.substr(space + 1);
133   auto it = type_map.find(type_str);
134   if (it == type_map.end()) {
135     throw std::runtime_error(
136         "FunctionParameter(): invalid type string: " + type_str);
137   }
138   type_ = it->second;
139 
140   auto eq = name_str.find('=');
141   if (eq != std::string::npos) {
142     name = name_str.substr(0, eq);
143     optional = true;
144     set_default_str(name_str.substr(eq + 1));
145   } else {
146     name = name_str;
147   }
148   python_name = THPUtils_internString(name);
149   auto np_compat_it = numpy_compatibility_arg_names.find(name);
150   if (np_compat_it != numpy_compatibility_arg_names.end()) {
151     for (const auto& str : np_compat_it->second) {
152       numpy_python_names.push_back(THPUtils_internString(str));
153     }
154   }
155 }
156 
handle_torch_function_getter(THPVariable * self,const std::string & property_name)157 auto handle_torch_function_getter(
158     THPVariable* self,
159     const std::string& property_name) -> PyObject* {
160   py::object torch_api = PyObject_FastGetAttrString(
161       THPVariableClass, (char*)property_name.c_str());
162   std::string module_name = "torch.Tensor." + property_name;
163   return handle_torch_function(
164       (PyObject*)self,
165       "__get__",
166       nullptr,
167       nullptr,
168       torch_api.ptr(),
169       module_name);
170 }
171 
handle_torch_function_setter(THPVariable * self,const std::string & property_name,PyObject * value)172 auto handle_torch_function_setter(
173     THPVariable* self,
174     const std::string& property_name,
175     PyObject* value) -> int {
176   py::object torch_api = PyObject_FastGetAttrString(
177       THPVariableClass, (char*)property_name.c_str());
178   std::string module_name = "torch.Tensor." + property_name;
179   if (value != nullptr) {
180     py::tuple args_ = py::make_tuple(py::handle(value));
181     handle_torch_function(
182         (PyObject*)self,
183         "__set__",
184         args_.ptr(),
185         nullptr,
186         torch_api.ptr(),
187         module_name);
188   } else {
189     handle_torch_function(
190         (PyObject*)self,
191         "__delete__",
192         nullptr,
193         nullptr,
194         torch_api.ptr(),
195         module_name);
196   }
197   return 0;
198 }
199 
200 // Combines self and args into one tuple.
combine_self_args(PyObject * self,PyObject * args)201 static auto combine_self_args(PyObject* self, PyObject* args) -> py::tuple {
202   if (args == nullptr) {
203     return py::make_tuple(py::handle(self));
204   } else if (self == nullptr) {
205     return py::reinterpret_borrow<py::tuple>(args);
206   }
207 
208   auto py_args = py::reinterpret_borrow<py::tuple>(args);
209   size_t n = py_args.size();
210   auto args_ = py::tuple(n + 1);
211   args_[0] = py::handle(self);
212   for (const auto i : c10::irange(n)) {
213     args_[i + 1] = py_args[i];
214   }
215   return args_;
216 }
217 
// TODO: I'm not sure if I should call this __torch_function__ or
// torch_function.  The former makes it easier to take an existing
// Tensor-like __torch_function__ object and turn it into a mode;
// but in general modes don't have to be Tensor-like (and we will
// improperly accept mode objects as arguments when they shouldn't
// be passed around in this way).
//
// Name of the method looked up on torch function mode objects.
// NOTE(review): this global is not referenced in the code visible in this
// chunk (dispatch_on_mode uses its own torch_function_name_str). Confirm its
// users elsewhere before changing it.
const char* torch_function_mode_name = "__torch_function__";
225 
handle_torch_function(PyObject * self,const std::string & func_name,PyObject * args,PyObject * kwargs,PyObject * torch_api,const std::string & module_name)226 auto handle_torch_function(
227     PyObject* self,
228     const std::string& func_name,
229     PyObject* args,
230     PyObject* kwargs,
231     PyObject* torch_api,
232     const std::string& module_name) -> PyObject* {
233   py::object torch_api_function =
234       PyObject_FastGetAttrString(torch_api, (char*)func_name.c_str());
235   TORCH_INTERNAL_ASSERT(
236       torch_api_function.ptr() != nullptr, "torch API function must exist");
237   py::tuple args_ = combine_self_args(self, args);
238   return handle_torch_function_no_python_arg_parser(
239       {self},
240       args_.ptr(),
241       kwargs,
242       func_name.c_str(),
243       torch_api_function.ptr(),
244       module_name.c_str(),
245       TorchFunctionName::TorchFunction);
246 }
247 
248 // Note: [Overloaded args]
249 // An overloaded arg may be one of the following:
250 // - an instance of an object that has a __torch_function__ method
251 // - an instance of an object that has a __torch_dispatch__ classmethod
252 // - a class type that has a __torch_dispatch__ classmethod
253 //
254 // This function returns the type of the arg (if the arg is an instance),
255 // otherwise, it returns the arg.
get_type_of_overloaded_arg(PyObject * obj_or_type)256 static PyObject* get_type_of_overloaded_arg(PyObject* obj_or_type) {
257   if (PyType_Check(obj_or_type)) {
258     return obj_or_type;
259   }
260   return (PyObject*)Py_TYPE(obj_or_type);
261 }
262 
maybe_get_registered_torch_dispatch_rule(PyObject * torch_api_function,const py::object & torch_dispatch_object)263 static py::object maybe_get_registered_torch_dispatch_rule(
264     PyObject* torch_api_function,
265     const py::object& torch_dispatch_object) {
266   // This is a static object, so we must leak the Python object
267   // "release()" is used here to preserve 1 refcount on the
268   // object, preventing it from ever being de-allocated by CPython.
269   static const py::handle find_torch_dispatch_rule =
270       py::object(py::module_::import("torch._library.simple_registry")
271                      .attr("find_torch_dispatch_rule"))
272           .release();
273   auto result = find_torch_dispatch_rule(
274       py::reinterpret_borrow<py::object>(torch_api_function),
275       torch_dispatch_object.get_type());
276   return result;
277 }
278 
dispatch_on_subclass(PyObject * args,PyObject * kwargs,at::ArrayRef<PyObject * > overloaded_args,py::tuple py_types,PyObject * torch_api_function,bool is_torch_function,const char * torch_function_name_str,std::optional<c10::impl::TorchDispatchModeKey> maybe_mode_key=std::nullopt)279 static py::object dispatch_on_subclass(
280     PyObject* args,
281     PyObject* kwargs,
282     at::ArrayRef<PyObject*> overloaded_args,
283     py::tuple py_types,
284     PyObject* torch_api_function,
285     bool is_torch_function,
286     const char* torch_function_name_str,
287     std::optional<c10::impl::TorchDispatchModeKey> maybe_mode_key =
288         std::nullopt) {
289   py::object ret;
290   for (auto& arg : overloaded_args) {
291     py::object torch_function =
292         PyObject_FastGetAttrString(arg, torch_function_name_str);
293     if (!torch_function) {
294       TORCH_INTERNAL_ASSERT(0);
295     }
296     if (torch_function.ptr() == torch::disabled_torch_dispatch_impl()) {
297       // During __torch_dispatch__, don't dispatch on args with a disabled
298       // torch_dispatch. This code runs before infra modes, so we need to make
299       // sure that infra modes can run first. (In theory, maybe we can rearrange
300       // things so that infra modes are *always* attempted first, and just
301       // return NotImplemented when there are any user subclasses. Maybe that
302       // would fix this problem?)
303       continue;
304     }
305 
306     // See https://github.com/pytorch/pytorch/issues/63767
307     if (is_torch_function &&
308         PyObject_FastGetAttrString(torch_function.ptr(), "__self__")
309             .is(py::handle(arg)) &&
310         torch_function.ptr() != torch::disabled_torch_function_impl()) {
311       TORCH_WARN_ONCE(
312           "Defining your `",
313           torch_function_name_str,
314           "` as a plain method is deprecated ",
315           "and will be an error in future, please define it as a classmethod.");
316     }
317 
318     if (!is_torch_function) {
319       auto maybe_torch_dispatch_rule = maybe_get_registered_torch_dispatch_rule(
320           torch_api_function, py::reinterpret_borrow<py::object>(arg));
321       if (!maybe_torch_dispatch_rule.is_none()) {
322         torch_function = maybe_torch_dispatch_rule;
323         auto py_arg = py::reinterpret_borrow<py::object>(arg);
324         ret = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(
325             torch_function.ptr(),
326             py_arg.get_type().ptr(),
327             torch_api_function,
328             py_types.ptr(),
329             args,
330             kwargs,
331             NULL));
332         if (ret.ptr() == nullptr) {
333           throw python_error();
334         }
335         if (ret.ptr() != Py_NotImplemented) {
336           break;
337         }
338       }
339     }
340 
341     ret = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(
342         torch_function.ptr(),
343         torch_api_function,
344         py_types.ptr(),
345         args,
346         kwargs,
347         NULL));
348     if (ret.ptr() == nullptr) {
349       throw python_error();
350     }
351     if (ret.ptr() != Py_NotImplemented) {
352       // Return the reference to the result. This also covers the case where
353       // ret is NULL and __torch_function__/__torch_dispatch raised an
354       // exception, which we throw below
355       break;
356     }
357   }
358   return ret;
359 }
360 
// Dispatches a call to the current mode. This is the current
// __torch_function__ mode when is_torch_function is true, and the current
// __torch_dispatch__ mode otherwise. The corresponding Stash*ModeGuard is held
// for the rest of this function; per the comment below, this disables the mode
// while its own handler runs.
//
// For __torch_dispatch__, a rule registered in torch._library.simple_registry
// for (torch_api_function, type(mode)) is called instead of the mode's own
// handler, and its result is returned as-is (even if it is NotImplemented).
//
// Returns (result, mode object). The result may be NotImplemented but is
// never null: a handler that raised is rethrown as python_error. The mode
// object is returned so the caller can name it in error messages.
static std::tuple<py::object, py::object> dispatch_on_mode(
    PyObject* args,
    PyObject* kwargs,
    py::tuple py_types,
    PyObject* torch_api_function,
    bool is_torch_function,
    const char* torch_function_name_str) {
  // Disable mode on the inside; this makes for a more user-friendly
  // experience if you try to, e.g., print your tensors.
  std::optional<torch::overrides::StashTorchFunctionModeGuard> tf_g;
  std::optional<torch_dispatch_mode::StashTorchDispatchModeGuard> td_g;
  py::object mode_obj;
  // NB: We only really need keep the mode_obj live if the function call
  // fails for error reporting, but whatever, Python refcounts are cheap
  if (is_torch_function) {
    tf_g.emplace();
    mode_obj = py::reinterpret_borrow<py::object>(
        tf_g->get_cur_mode()->ptr(getPyInterpreter()));
  } else {
    td_g.emplace();
    mode_obj = py::reinterpret_borrow<py::object>(
        td_g->get_cur_mode()->ptr(getPyInterpreter()));
  }
  py::object torch_function =
      PyObject_FastGetAttrString(mode_obj.ptr(), torch_function_name_str);
  if (!torch_function) {
    TORCH_INTERNAL_ASSERT(0);
  }
  TORCH_INTERNAL_ASSERT(py_types.ptr() != nullptr);
  TORCH_INTERNAL_ASSERT(args != nullptr);

  // The handler must be bound to this mode instance (a plain method); a
  // classmethod would be bound to the mode's class instead.
  TORCH_CHECK(
      PyObject_FastGetAttrString(torch_function.ptr(), "__self__").is(mode_obj),
      "Defining your mode's `",
      torch_function_name_str,
      "` as a classmethod is not supported, please make it a plain method");

  if (!is_torch_function) {
    auto maybe_torch_dispatch_rule =
        maybe_get_registered_torch_dispatch_rule(torch_api_function, mode_obj);
    if (!maybe_torch_dispatch_rule.is_none()) {
      auto ret = py::reinterpret_steal<py::object>(PyObject_CallFunctionObjArgs(
          maybe_torch_dispatch_rule.ptr(),
          mode_obj.ptr(),
          torch_api_function,
          py_types.ptr(),
          args,
          kwargs,
          NULL));
      if (ret.ptr() == nullptr) {
        throw python_error();
      }
      return std::make_tuple(ret, mode_obj);
    }
  }

  // Blegh.  With PyObject_CallFunctionObjArgs (as in the registered-rule call
  // above), a null kwargs only "accidentally works" because the nullptr also
  // terminates the argument list, ick ick ick. PyObject_CallMethod's format
  // string must spell out the argument count, so the two cases are separate.
  py::object ret;
  if (kwargs == nullptr) {
    ret = py::reinterpret_steal<py::object>(PyObject_CallMethod(
        mode_obj.ptr(),
        torch_function_name_str,
        "OOO",
        torch_api_function,
        py_types.ptr(),
        args));
  } else {
    ret = py::reinterpret_steal<py::object>(PyObject_CallMethod(
        mode_obj.ptr(),
        torch_function_name_str,
        "OOOO",
        torch_api_function,
        py_types.ptr(),
        args,
        kwargs));
  }
  if (ret.ptr() == nullptr) {
    throw python_error();
  }
  return std::make_tuple(ret, mode_obj);
}
443 
444 // See Note: [Overloaded args] for what they hold
handle_torch_function_no_python_arg_parser(at::ArrayRef<PyObject * > overloaded_args,PyObject * args,PyObject * kwargs,const char * func_name,PyObject * torch_api_function,const char * module_name,TorchFunctionName torch_function_name)445 auto handle_torch_function_no_python_arg_parser(
446     at::ArrayRef<PyObject*> overloaded_args,
447     PyObject* args,
448     PyObject* kwargs,
449     const char* func_name,
450     PyObject* torch_api_function,
451     const char* module_name,
452     TorchFunctionName torch_function_name) -> PyObject* {
453   const char* torch_function_name_str = nullptr;
454   switch (torch_function_name) {
455     case TorchFunctionName::TorchFunction:
456       torch_function_name_str = "__torch_function__";
457       break;
458     case TorchFunctionName::TorchDispatch:
459       torch_function_name_str = "__torch_dispatch__";
460       break;
461     default:
462       TORCH_INTERNAL_ASSERT(0, static_cast<int>(torch_function_name));
463   }
464   // overloaded_args already all have unique types
465   // nb: modes don't go in the overloaded types list, as they are not
466   // necessarily types
467   std::vector<py::object> overloaded_types;
468   overloaded_types.reserve(overloaded_args.size());
469   for (auto& arg : overloaded_args) {
470     overloaded_types.push_back(
471         py::reinterpret_borrow<py::object>(get_type_of_overloaded_arg(arg)));
472   }
473   py::tuple py_types = py::cast(overloaded_types);
474   py::object ret;
475   py::object mode_obj;
476 
477   // Step 1: Try to dispatch based on the mode stack, *ignoring* infra
478   // torch_dispatch modes.
479   const bool is_torch_function =
480       torch_function_name == TorchFunctionName::TorchFunction;
481   const auto is_mode_active = [&]() {
482     return is_torch_function
483         ? at::impl::torch_function_mode_enabled()
484         // Check if any *user* torch_dispatch modes are active (not including
485         // fake and proxy modes, which are special)
486         : c10::impl::dispatch_mode_enabled();
487   };
488   // Note [__torch_dispatch__ dispatching order]
489   // The high-level idea motivating the dispatching
490   // order below is that: (1) modes get higher dispatch precedence over
491   // subclasses (2) "user" modes/subclasses get higher dispatch precedence over
492   // "infra" modes/subclasses.
493   //
494   // To give a complete example: let's say we are running torch.compile, with
495   // the following "user" modes and subclasses:
496   //   mode_stack: [ModeA]
497   //   user_args: [MyWrapperSubclassB(torchTensor)]
498 
499   // During tracing in AOTAutograd tracing, we use some additional infra modes
500   // and subclasses to perform tracing:
501   //   FunctionalTensorMode, ProxyTorchDispatchMode, FakeTensorMode,
502   //   FunctionalTensor, FakeTensor
503   // The modified mode stack and tracing arguments will look like this:
504   //   mode_stack (user modes): [ModeA]
505   //   mode_stack (infra modes): [
506   //     FunctionalTensorMode, ProxyTorchDispatchMode, FakeTensorMode
507   //   ]
508   //   tracing_args: [
509   //     MyWrapperSubclassB(FunctionalTensor(_to_functional_tensor(FakeTensor)))
510   //   ]
511 
512   // And the dispatching order that we want is as follows:
513   // (1) ModeA.__torch_dispatch__ (user modes highest)
514   // (2) MyWrapperSubclassB.__torch_dispatch__ (user subclasses next highest)
515   // (3) FunctionalTensorMode.__torch_dispatch__ (infra modes next highest)
516   // (4) ProxyTorchDispatchMode.__torch_dispatch__ (infra modes next highest)
517   // (5) FakeTensorMode.__torch_dispatch__ (infra modes next highest)
518   // (6) FakeTensor.__torch_fake_dispatch__ (infra subclasses next highest)
519 
520   // Why does do FunctionalTensor and FakeTensor even need to be special-cased
521   // in the ordering?
522   // In theory we could remove their __torch_dispatch__, but both of these
523   // subclasses override sizes/strides metadata calls with __torch_dispatch__,
524   // which would mean a mode would be **required** to access their metadata.
525 
526   if (is_mode_active()) {
527     // Step 1: Try to dispatch on any user TorchDispatchModes (including infra
528     // modes, which will always be at the bottom of the mode stack).
529     auto ret_ = dispatch_on_mode(
530         args,
531         kwargs,
532         py_types,
533         torch_api_function,
534         is_torch_function,
535         torch_function_name_str);
536     ret = std::get<0>(ret_);
537     mode_obj = std::get<1>(ret_);
538   }
539 
540   // Step 2: Try to dispatch based on any user subclasses,
541   // ignoring any subclasses that have a _mode_key field
542   // (corresponding to infra subclasses)
543   // Note: user subclasses should always run *before* infra modes like
544   // proxy/fake. This is handles by having proxy/fake modes return
545   // NotImplemented when they see a user subclass that they don't understand.
546   if (ret.ptr() == nullptr || ret.ptr() == Py_NotImplemented) {
547     auto curr_ret = dispatch_on_subclass(
548         args,
549         kwargs,
550         overloaded_args,
551         py_types,
552         torch_api_function,
553         is_torch_function,
554         torch_function_name_str);
555     if (curr_ret.ptr() != nullptr) {
556       ret = curr_ret;
557     }
558   }
559 
560   if (ret.ptr() == nullptr) {
561     // if an exception occurred in a user's implementation of
562     // __torch_function__, throw it
563     throw python_error();
564   } else if (ret.ptr() == Py_NotImplemented) {
565     // all __torch_function__ implementations in overloaded_args
566     // returned NotImplemented, so we raise a TypeError.
567     std::stringstream ss;
568     ss << "Multiple dispatch failed for '";
569     if (module_name && func_name) {
570       ss << module_name << "." << func_name;
571     } else {
572       py::handle fn = torch_api_function;
573       ss << py::str(fn.attr("__module__")) << "."
574          << py::str(fn.attr("__name__"));
575     }
576     ss << "'; all " << torch_function_name_str
577        << " handlers returned NotImplemented:\n\n";
578     if (mode_obj) {
579       ss << "  - mode object " << py::repr(mode_obj) << "\n";
580     }
581     for (auto& arg : overloaded_args) {
582       ss << "  - tensor subclass " << py::repr(get_type_of_overloaded_arg(arg))
583          << "\n";
584     }
585     ss << "\nFor more information, try re-running with TORCH_LOGS=not_implemented";
586     const std::string& tmp = ss.str();
587     PyErr_SetString(PyExc_TypeError, tmp.c_str());
588     throw python_error();
589   }
590   return ret.release().ptr();
591 }
592 
handle_torch_function(PythonArgs & r,PyObject * self,PyObject * args,PyObject * kwargs,PyObject * torch_api,const char * module_name,const char * func_name_override)593 auto handle_torch_function(
594     PythonArgs& r,
595     PyObject* self,
596     PyObject* args,
597     PyObject* kwargs,
598     PyObject* torch_api,
599     const char* module_name,
600     const char* func_name_override) -> PyObject* {
601   py::object torch_api_function = PyObject_FastGetAttrString(
602       torch_api,
603       (char*)(func_name_override ? func_name_override
604                                  : r.get_func_name().c_str()));
605   TORCH_INTERNAL_ASSERT(
606       torch_api_function.ptr() != nullptr, "torch API function must exist");
607   py::tuple args_ = combine_self_args(self, args);
608   return handle_torch_function_no_python_arg_parser(
609       r.overloaded_args,
610       args_.ptr(),
611       kwargs,
612       r.get_func_name().c_str(),
613       torch_api_function.ptr(),
614       module_name);
615 }
616 
handle_torch_function(PythonArgs & r,PyObject * args,PyObject * kwargs,PyObject * torch_api,const char * module_name,const char * func_name_override)617 auto handle_torch_function(
618     PythonArgs& r,
619     PyObject* args,
620     PyObject* kwargs,
621     PyObject* torch_api,
622     const char* module_name,
623     const char* func_name_override) -> PyObject* {
624   return handle_torch_function(
625       r, nullptr, args, kwargs, torch_api, module_name, func_name_override);
626 }
627 
handle_torch_function_indexing(PyObject * self,PyObject * index,PyObject * val)628 auto handle_torch_function_indexing(
629     PyObject* self,
630     PyObject* index,
631     PyObject* val) -> PyObject* {
632   const char* func_name = (val == nullptr) ? "__getitem__" : "__setitem__";
633   py::object index_tup;
634   if (PyTuple_Check(index)) {
635     index_tup = py::reinterpret_borrow<py::object>(index);
636   } else {
637     index_tup = py::make_tuple(py::handle(index));
638   }
639   std::vector<PyObject*> overridable_args;
640   is_tensor_and_append_overloaded(self, &overridable_args);
641   auto size = PyTuple_GET_SIZE(index_tup.ptr());
642   for (auto i : c10::irange(size)) {
643     auto* obj = PyTuple_GetItem(index_tup.ptr(), i);
644     is_tensor_and_append_overloaded(obj, &overridable_args);
645   }
646   if (val != nullptr) {
647     is_tensor_and_append_overloaded(val, &overridable_args);
648   }
649   py::object func =
650       PyObject_FastGetAttrString(THPVariableClass, (char*)func_name);
651   py::object args = (val == nullptr)
652       ? py::make_tuple(py::handle(self), py::handle(index))
653       : py::make_tuple(py::handle(self), py::handle(index), py::handle(val));
654   return handle_torch_function_no_python_arg_parser(
655       overridable_args,
656       args.ptr(),
657       nullptr,
658       func_name,
659       func.ptr(),
660       "torch.Tensor");
661 }
662 
663 /*
664  *  obj has a __torch_function__ implementation and may either be a
665  *  subclass of Tensor or a Tensor-like duck type. We may need to
666  *  append this object to the overloaded_args vector, which tracks all
667  *  of the arguments with distinct __torch_function__ implementations
668  *  we've seen so far.
669  *
670  *  If this is the first argument we've seen with __torch_function__
671  *  defined, we unconditionally add obj to the overloaded_args vector.
672  *
673  *  If we've already seen arguments with __torch_function__ defined,
674  *  then we first need to check if obj is the same type as any of the
675  *  entries in overloaded_args.  If so, we can ignore obj since we
676  *  already have an entry in overloaded_args with the same
677  *  __torch_function__ implementation.
678  *
679  *  If it's a different type, we then need to check if it's a subclass
680  *  of one of the types we've already seen. If so, we need to insert an
681  *  entry in overloaded_args for this type with higher precedence than
682  *  the superclass.
683  *
684  *  See torch._overrides._get_overloaded_args for the equivalent
685  *  function in the Python __torch_function__ implementation.
686  *
687  *  The precedence-determining algorithm implemented in this function is
688  *  described in NEP-0018:
689  *  https://numpy.org/neps/nep-0018-array-function-protocol.html
690  *
691  *  'overloaded_args' is a raw pointer to a vector of pybind11 handles
692  *  that have distinct __torch_function__ implementations, in order of calling
693  *  precedence.
694  *
695  *  'obj' is an object to check for a __torch_function__ implementation
696  *
697  * If changing this file in a way that can affect the __torch_function__
698  * overhead, please report the benchmarks in 'benchmarks/overrides_benchmark'.
699  * See the instructions in the 'README.md' in that directory.
700  *
701  */
702 
append_overloaded_arg(std::vector<PyObject * > * overloaded_args,PyObject * obj,bool obj_is_type)703 static void append_overloaded_arg(
704     std::vector<PyObject*>* overloaded_args,
705     PyObject* obj,
706     bool obj_is_type) {
707   bool class_not_seen_yet = true;
708   PyObject* obj_type = obj_is_type ? obj : (PyObject*)Py_TYPE(obj);
709   for (auto& arg : *overloaded_args) {
710     if (obj_type == get_type_of_overloaded_arg(arg)) {
711       // obj is the same type as another parameter we've seen in a prior
712       // iteration of the loop over parameters so we already have an entry
713       // with the proper __torch_function__ implementation to call, so skip
714       // this parameter
715       class_not_seen_yet = false;
716       break;
717     }
718   }
719   if (class_not_seen_yet) {
720     auto arg_index = overloaded_args->size();
721     for (const auto j : c10::irange(arg_index)) {
722       if (PyObject_IsSubclass(
723               obj_type, get_type_of_overloaded_arg((*overloaded_args)[j]))) {
724         // obj is a subclass of another object we've seen already so its
725         // __torch_function__ should be called first, therefore we
726         // insert it into overloaded_args before the superclass
727         arg_index = j;
728         break;
729       }
730     }
731     // add object to overloaded_args. If it's a subclass of another class
732     // we've already seen it will be inserted before the superclass,
733     // otherwise it will be inserted at the end of the array
734     overloaded_args->insert(
735         overloaded_args->begin() + static_cast<long>(arg_index), obj);
736   }
737 }
738 
append_overloaded_tensor(std::vector<PyObject * > * overloaded_args,PyObject * obj)739 void append_overloaded_tensor(
740     std::vector<PyObject*>* overloaded_args,
741     PyObject* obj) {
742   append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/ false);
743 }
744 
append_overloaded_type(std::vector<PyObject * > * overloaded_args,PyObject * obj)745 void append_overloaded_type(
746     std::vector<PyObject*>* overloaded_args,
747     PyObject* obj) {
748   append_overloaded_arg(overloaded_args, obj, /*obj_is_type*/ true);
749 }
750 
is_tensor_and_append_overloaded(PyObject * obj,std::vector<PyObject * > * overloaded_args)751 bool is_tensor_and_append_overloaded(
752     PyObject* obj,
753     std::vector<PyObject*>* overloaded_args) {
754   if (THPVariable_CheckExact(obj)) {
755     // torch.Tensor instances (not subclasses, except for Parameter)
756     return true;
757   }
758 
759   if (check_has_torch_function(obj, /*ignore_mode*/ true)) {
760     // tensor subclasses and unrelated objects with __torch_function__
761     append_overloaded_tensor(overloaded_args, obj);
762     return true;
763   } else if (THPVariable_Check(obj)) {
764     // tensor subclasses without __torch_function__
765     return true;
766   }
767 
768   return false;
769 }
770 
is_scalar_list(PyObject * obj)771 static bool is_scalar_list(PyObject* obj) {
772   auto tuple = six::isTuple(obj);
773   if (!(tuple || PyList_Check(obj))) {
774     return false;
775   }
776   // NOLINTNEXTLINE(bugprone-branch-clone)
777   const auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
778   for (const auto idx : c10::irange(size)) {
779     PyObject* iobj =
780         tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
781     if (!THPUtils_checkScalar(iobj)) {
782       return false;
783     }
784   }
785   return true;
786 }
787 
is_tensor_list_and_append_overloaded(PyObject * obj,std::vector<PyObject * > * overloaded_args,size_t argnum,bool throw_error)788 bool is_tensor_list_and_append_overloaded(
789     PyObject* obj,
790     std::vector<PyObject*>* overloaded_args,
791     size_t argnum,
792     bool throw_error) {
793   auto tuple = six::isTuple(obj);
794   if (!(tuple || PyList_Check(obj))) {
795     return false;
796   }
797   // NOLINTNEXTLINE(bugprone-branch-clone)
798   const auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
799   for (long idx = 0; idx < size; idx++) {
800     PyObject* iobj =
801         tuple ? PyTuple_GET_ITEM(obj, idx) : PyList_GET_ITEM(obj, idx);
802     if (!is_tensor_and_append_overloaded(iobj, overloaded_args)) {
803       if (throw_error) {
804         TORCH_CHECK_TYPE(
805             false,
806             "expected Tensor as element ",
807             idx,
808             " in argument ",
809             argnum,
810             ", but got ",
811             Py_TYPE(iobj)->tp_name);
812       }
813       return false;
814     }
815   }
816   return true;
817 }
818 
is_float_or_complex_list(PyObject * obj)819 static bool is_float_or_complex_list(PyObject* obj) {
820   auto tuple = six::isTuple(obj);
821   if (!(tuple || PyList_Check(obj))) {
822     return false;
823   }
824 
825   // NOLINTNEXTLINE(bugprone-branch-clone)
826   const auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
827   if (size > 0) {
828     PyObject* iobj = tuple ? PyTuple_GET_ITEM(obj, 0) : PyList_GET_ITEM(obj, 0);
829     if (!THPUtils_checkDouble(iobj) && !PyComplex_Check(iobj)) {
830       return false;
831     }
832   }
833 
834   return true;
835 }
836 
is_int_or_symint(PyObject * obj)837 static bool is_int_or_symint(PyObject* obj) {
838   // THPUtils_checkIndex may call __index__ or __int__
839   // which may have side effects if obj is a symint node
840   // so we do `is_symint` check first
841   // TODO: maybe we should be using checkLong here?
842   if (torch::is_symint(py::handle(obj))) {
843     return true;
844   }
845 
846   // FakeTensor(..., size=()) is qualified for SymInt param,
847   // but we can't go via __index__ (below) as we would normally
848   // do for regular tensors, because __index__ first forces a
849   // conversion into an int, which in general you cannot do
850   // if you have an unbacked SymInt.  So this fastpath ensures
851   // that we still allow for fake tensors in this case, but
852   // for regular tensors it's redundant with the test below.
853   if (THPVariable_Check(obj)) {
854     auto& var = THPVariable_Unpack(obj);
855     if (TORCH_GUARD_SIZE_OBLIVIOUS(var.sym_numel().sym_eq(1)) &&
856         at::isIntegralType(var.dtype().toScalarType(), /*include_bool*/ true)) {
857       return true;
858     }
859   }
860 
861   if (THPUtils_checkIndex(obj)) {
862     return true;
863   }
864 
865   return false;
866 }
867 
is_int_or_symint_list(PyObject * obj,int broadcast_size,int64_t * failed_idx=nullptr)868 static bool is_int_or_symint_list(
869     PyObject* obj,
870     int broadcast_size,
871     int64_t* failed_idx = nullptr) {
872   if (PyTuple_Check(obj) || PyList_Check(obj)) {
873     if (PySequence_Size(obj) == 0) {
874       return true;
875     }
876     auto item = py::reinterpret_steal<py::object>(PySequence_GetItem(obj, 0));
877 
878     if (is_int_or_symint(item.ptr())) {
879       return true;
880     }
881 
882     // NOTE: JIT tracer allows arbitrary scalar tensors to act as ints
883     // in an intlist argument. Even float or complex scalar tensors.
884     bool r =
885         (jit::tracer::isTracing() && THPVariable_Check(item.ptr()) &&
886          THPVariable_Unpack(item.ptr()).sizes().empty());
887     if (!r && failed_idx != nullptr) {
888       *failed_idx = 0;
889     }
890     return r;
891   }
892 
893   // if a size is specified (e.g. IntArrayRef[2]) we also allow passing a single
894   // int
895   return broadcast_size > 0 && is_int_or_symint(obj);
896 }
897 
898 // argnum is needed for raising the TypeError, it's used in the error message.
check(PyObject * obj,std::vector<PyObject * > & overloaded_args,int argnum,int64_t * failed_idx)899 auto FunctionParameter::check(
900     PyObject* obj,
901     std::vector<PyObject*>& overloaded_args,
902     int argnum,
903     int64_t* failed_idx) -> bool {
904   switch (type_) {
905     case ParameterType::TENSOR: {
906       if (is_tensor_and_append_overloaded(obj, &overloaded_args)) {
907         return true;
908       }
909       if (allow_numbers_as_tensors) {
910         return THPUtils_checkScalar(obj);
911       }
912       return false;
913     }
914     case ParameterType::SCALAR:
915       if (THPUtils_checkScalar(obj)) {
916         return true;
917       }
918       [[fallthrough]];
919     case ParameterType::COMPLEX:
920       if (PyComplex_Check(obj)) {
921         return true;
922       }
923       [[fallthrough]];
924     case ParameterType::DOUBLE: {
925       if (THPUtils_checkDouble(obj)) {
926         return true;
927       }
928       if (THPVariable_Check(obj)) {
929         const auto& var = THPVariable_Unpack(obj);
930         return !var.requires_grad() && var.dim() == 0;
931       }
932       if (torch::is_symfloat(py::handle(obj)) ||
933           torch::is_symint(py::handle(obj))) {
934         // This will induce a guard
935         return true;
936       }
937       return false;
938     }
939     case ParameterType::INT64: {
940       if (THPUtils_checkLong(obj)) {
941         return true;
942       }
943       if (THPVariable_Check(obj)) {
944         const auto& var = THPVariable_Unpack(obj);
945         return at::isIntegralType(var.scalar_type(), /*includeBool=*/false) &&
946             !var.requires_grad() && var.dim() == 0;
947       }
948       if (torch::is_symint(py::handle(obj))) {
949         // This will induce a guard
950         return true;
951       }
952       return false;
953     }
954     case ParameterType::DIMNAME:
955       return THPUtils_checkDimname(obj);
956     case ParameterType::DIMNAME_LIST: {
957       if (THPUtils_checkDimnameList(obj)) {
958         return true;
959       }
960       // if a size is specified (e.g. DimnameList[1]) we also allow passing a
961       // single Dimname
962       return size == 1 && THPUtils_checkDimname(obj);
963     }
964     case ParameterType::TENSOR_LIST: {
965       return is_tensor_list_and_append_overloaded(
966           obj, &overloaded_args, argnum, true /* throw_error */);
967     }
968     case ParameterType::FLOAT_LIST:
969       return is_float_or_complex_list(obj);
970     case ParameterType::GENERATOR:
971       return THPGenerator_Check(obj);
972     case ParameterType::BOOL:
973       return PyBool_Check(obj);
974     case ParameterType::STORAGE:
975       return isStorage(obj);
976     case ParameterType::PYOBJECT:
977       return true;
978     case ParameterType::SCALARTYPE:
979       return THPDtype_Check(obj) || THPPythonScalarType_Check(obj);
980     case ParameterType::LAYOUT:
981       return THPLayout_Check(obj);
982     case ParameterType::MEMORY_FORMAT:
983       return THPMemoryFormat_Check(obj);
984     case ParameterType::QSCHEME:
985       return THPQScheme_Check(obj);
986     case ParameterType::DEVICE:
987       // Allow symint to be passed in as device, but we'll specialize and
988       // guard in this case.
989       return THPUtils_checkLong(obj) || THPUtils_checkString(obj) ||
990           THPDevice_Check(obj) || torch::is_symint(py::handle(obj));
991     case ParameterType::STREAM:
992       return THPStream_Check(obj);
993     case ParameterType::STRING:
994       return THPUtils_checkString(obj);
995     case ParameterType::SCALAR_LIST:
996       return is_scalar_list(obj);
997     case ParameterType::SYM_INT:
998       return is_int_or_symint(obj);
999     // Allow SymInt where int is expected; we'll guard in this case
1000     case ParameterType::INT_LIST:
1001     case ParameterType::SYM_INT_LIST:
1002       return is_int_or_symint_list(obj, size, failed_idx);
1003     case ParameterType::DISPATCH_KEY_SET:
1004       return py::isinstance<c10::DispatchKeySet>(py::handle(obj));
1005     default:
1006       throw std::runtime_error("unknown parameter type");
1007   }
1008 }
1009 
1010 // WARNING: these strings are parsed invalid_arguments.cpp
type_name() const1011 std::string FunctionParameter::type_name() const {
1012   switch (type_) {
1013     case ParameterType::TENSOR:
1014       return "Tensor";
1015     case ParameterType::SCALAR:
1016       return "Number";
1017     case ParameterType::INT64:
1018     // NB: SymInt is intentionally not mentioned here, as conventional user
1019     // use will only know about ints
1020     case ParameterType::SYM_INT:
1021       return "int";
1022     case ParameterType::DOUBLE:
1023       return "float";
1024     case ParameterType::COMPLEX:
1025       return "complex";
1026     case ParameterType::TENSOR_LIST:
1027       return "tuple of Tensors";
1028     case ParameterType::INT_LIST:
1029       return "tuple of ints";
1030     case ParameterType::FLOAT_LIST:
1031       return "tuple of floats";
1032     case ParameterType::GENERATOR:
1033       return "torch.Generator";
1034     case ParameterType::BOOL:
1035       return "bool";
1036     case ParameterType::STORAGE:
1037       return "torch.Storage";
1038     case ParameterType::PYOBJECT:
1039       return "object";
1040     case ParameterType::SCALARTYPE:
1041       return "torch.dtype";
1042     case ParameterType::LAYOUT:
1043       return "torch.layout";
1044     case ParameterType::MEMORY_FORMAT:
1045       return "torch.memory_format";
1046     case ParameterType::QSCHEME:
1047       return "torch.qscheme";
1048     case ParameterType::DEVICE:
1049       return "torch.device";
1050     case ParameterType::STRING:
1051       return "str";
1052     case ParameterType::DIMNAME:
1053       return "name";
1054     case ParameterType::DIMNAME_LIST:
1055       return "tuple of names";
1056     case ParameterType::SCALAR_LIST:
1057       return "tuple of Scalars";
1058     case ParameterType::SYM_INT_LIST:
1059       return "tuple of ints";
1060     case ParameterType::DISPATCH_KEY_SET:
1061       return "DispatchKeySet";
1062     default:
1063       throw std::runtime_error("unknown parameter type");
1064   }
1065 }
1066 
// Tries to interpret the whole of `s` as an integer literal, using strtoll
// with base 0 (decimal, 0x-prefixed hex, or 0-prefixed octal).
// Returns std::nullopt if `s` is empty or has trailing characters that are not
// part of the integer (e.g. "1.5", "1e-5"), letting callers fall back to
// floating-point parsing.
static inline std::optional<int64_t> parse_as_integer(const std::string& s) {
  if (s.empty())
    return std::nullopt;
  char* str_end = nullptr;
  // strtoll rather than strtol: `long` is only 32 bits on LLP64 platforms
  // (e.g. Windows), where strtol would clamp 64-bit literals to LONG_MAX /
  // LONG_MIN even though the result type is int64_t.
  const long long ans = std::strtoll(s.c_str(), &str_end, 0);
  // *str_end == 0 if the entire string was parsed as an integer.
  return (*str_end == 0) ? std::optional<int64_t>(static_cast<int64_t>(ans))
                         : std::nullopt;
}
1075 
1076 /*
1077 Parse default value of IntArrayRef declared at native_functions.yaml
1078 
1079 There are two kinds of default values:
1080 1. IntArrayRef[2] x=1 (where size=2, value={1,1}
1081 2. IntArrayRef x={1,2,3} (where size=3, value={1,2,3}, note that there cannot be
1082 space after comma since native_parse.py uses ', ' to split args)
1083 */
parse_intlist_args(const std::string & s,int64_t size)1084 static inline std::vector<int64_t> parse_intlist_args(
1085     const std::string& s,
1086     int64_t size) {
1087   size_t n = s.size();
1088 
1089   if (s.empty())
1090     return std::vector<int64_t>();
1091 
1092   // case 1. s is an int (e.g., s=2)
1093   if (s[0] != '{') {
1094     TORCH_CHECK(size > 0, "Incorrect size of IntArrayRef: ", size);
1095     return std::vector<int64_t>(size, std::stol(s));
1096   }
1097 
1098   // case 2. s is a list of dims (e.g., s={1,2})
1099 
1100   // since already checked left brace '{' above, here only checks right brace
1101   // '}'
1102   TORCH_CHECK(
1103       s[n - 1] == '}',
1104       "Default value of IntArrayRef is missing right brace '}', found ",
1105       s[n - 1]);
1106 
1107   auto args = std::vector<int64_t>();
1108   std::istringstream ss(s.substr(1, s.length() - 2)); // exclude '{' and '}'
1109   std::string tok;
1110 
1111   while (std::getline(ss, tok, ',')) {
1112     args.emplace_back(std::stol(tok));
1113   }
1114   return args;
1115 }
1116 
1117 // Parse a string literal to remove quotes and escape sequences
parse_string_literal(c10::string_view str)1118 static std::string parse_string_literal(c10::string_view str) {
1119   TORCH_CHECK(str.length() >= 2, "String defaults must be quoted");
1120 
1121   if (str.front() == '"') {
1122     TORCH_CHECK(
1123         str.back() == '"', "Mismatched quotes in string default: ", str);
1124   } else {
1125     TORCH_CHECK(
1126         str.front() == '\'' && str.back() == '\'',
1127         "Invalid quotes in string default: ",
1128         str)
1129   }
1130 
1131   std::string parsed;
1132   parsed.reserve(str.size());
1133   for (size_t i = 1; i < str.size() - 1;) {
1134     if (str[i] != '\\') {
1135       parsed.push_back(str[i]);
1136       ++i;
1137       continue;
1138     }
1139 
1140     // Handle escape sequences
1141     TORCH_CHECK(
1142         i < str.size() - 2, "String ends with escaped final quote: ", str)
1143     char c = str[i + 1];
1144     switch (c) {
1145       case '\\':
1146       case '\'':
1147       case '\"':
1148         break;
1149       case 'a':
1150         c = '\a';
1151         break;
1152       case 'b':
1153         c = '\b';
1154         break;
1155       case 'f':
1156         c = '\f';
1157         break;
1158       case 'n':
1159         c = '\n';
1160         break;
1161       case 'v':
1162         c = '\v';
1163         break;
1164       case 't':
1165         c = '\t';
1166         break;
1167       default:
1168         TORCH_CHECK(
1169             false,
1170             "Unsupported escape sequence in string default: \\",
1171             str[i + 1]);
1172     }
1173     parsed.push_back(c);
1174     i += 2;
1175   }
1176   return parsed;
1177 }
1178 
set_default_str(const std::string & str)1179 void FunctionParameter::set_default_str(const std::string& str) {
1180   if (str == "None") {
1181     allow_none = true;
1182   }
1183   if (type_ == ParameterType::TENSOR ||
1184       type_ == ParameterType::DISPATCH_KEY_SET) {
1185     if (str != "None") {
1186       throw std::runtime_error(
1187           "default value for Tensor must be none, got: " + str);
1188     }
1189   } else if (type_ == ParameterType::INT64 || type_ == ParameterType::SYM_INT) {
1190     default_int = atol(str.c_str());
1191   } else if (type_ == ParameterType::BOOL) {
1192     default_bool = (str == "True" || str == "true");
1193   } else if (type_ == ParameterType::DOUBLE) {
1194     default_double = atof(str.c_str());
1195   } else if (type_ == ParameterType::COMPLEX) {
1196     default_complex[0] = atof(str.c_str()); // TODO: parse "x + xj"?
1197     default_complex[1] = 0;
1198   } else if (type_ == ParameterType::SCALAR) {
1199     if (str != "None") {
1200       // we sometimes rely on integer-vs-float values, e.g. with arange.
1201       const auto as_integer = parse_as_integer(str);
1202       default_scalar = as_integer.has_value() ? at::Scalar(as_integer.value())
1203                                               : at::Scalar(atof(str.c_str()));
1204     }
1205   } else if (
1206       type_ == ParameterType::INT_LIST ||
1207       type_ == ParameterType::SYM_INT_LIST) {
1208     if (str != "None") {
1209       default_intlist = parse_intlist_args(str, size);
1210     }
1211   } else if (type_ == ParameterType::FLOAT_LIST) {
1212     if (str != "None") {
1213       throw std::runtime_error("Defaults not supported for float[]");
1214     }
1215   } else if (type_ == ParameterType::SCALARTYPE) {
1216     if (str == "None") {
1217       default_scalartype = at::ScalarType::Undefined;
1218     } else if (str == "torch.int64") {
1219       default_scalartype = at::ScalarType::Long;
1220     } else {
1221       throw std::runtime_error("invalid default value for ScalarType: " + str);
1222     }
1223   } else if (type_ == ParameterType::LAYOUT) {
1224     if (str == "None") {
1225       TORCH_INTERNAL_ASSERT_DEBUG_ONLY(allow_none);
1226     } else if (str == "torch.strided") {
1227       default_layout = at::Layout::Strided;
1228     } else if (str == "torch.sparse_coo") {
1229       default_layout = at::Layout::Sparse;
1230     } else {
1231       throw std::runtime_error("invalid default value for layout: " + str);
1232     }
1233   } else if (type_ == ParameterType::DEVICE) {
1234     if (str != "None") {
1235       throw std::runtime_error("invalid device: " + str);
1236     }
1237   } else if (type_ == ParameterType::STREAM) {
1238     if (str != "None") {
1239       throw std::runtime_error("invalid stream: " + str);
1240     }
1241   } else if (type_ == ParameterType::STRING) {
1242     if (str != "None") {
1243       default_string = parse_string_literal(str);
1244     }
1245   }
1246   // These types weren't handled here before. Adding a default error
1247   // led to a lot of test failures so adding this skip for now.
1248   // We should correctly handle these though because it might be causing
1249   // silent failures.
1250   else if (type_ == ParameterType::TENSOR_LIST) { // NOLINT
1251     // throw std::runtime_error("Invalid Tensor List");
1252   } else if (type_ == ParameterType::GENERATOR) { // NOLINT
1253     // throw std::runtime_error("ParameterType::GENERATOR");
1254   } else if (type_ == ParameterType::PYOBJECT) { // NOLINT
1255     // throw std::runtime_error("ParameterType::PYOBJECT");
1256   } else if (type_ == ParameterType::MEMORY_FORMAT) { // NOLINT
1257     // throw std::runtime_error("ParameterType::MEMORY_FORMAT");
1258   } else if (type_ == ParameterType::DIMNAME) { // NOLINT
1259     // throw std::runtime_error("ParameterType::DIMNAME");
1260   } else if (type_ == ParameterType::DIMNAME_LIST) { // NOLINT
1261     // throw std::runtime_error("ParameterType::DIMNAME_LIST");
1262   } else if (type_ == ParameterType::SCALAR_LIST) { // NOLINT
1263     // throw std::runtime_error("ParameterType::SCALAR_LIST");
1264   } else if (type_ == ParameterType::STORAGE) { // NOLINT
1265     // throw std::runtime_error("ParameterType::STORAGE");
1266   } else if (type_ == ParameterType::QSCHEME) { // NOLINT
1267     // throw std::runtime_error("ParameterType::QSCHEME");
1268   } else {
1269     throw std::runtime_error("unknown parameter type");
1270   }
1271   default_value = str;
1272 }
1273 
1274 // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
FunctionSignature(const std::string & fmt,int index)1275 FunctionSignature::FunctionSignature(const std::string& fmt, int index)
1276     : min_args(0),
1277       max_args(0),
1278       max_pos_args(0),
1279       index(index),
1280       hidden(false),
1281       deprecated(false) {
1282   auto open_paren = fmt.find('(');
1283   if (open_paren == std::string::npos) {
1284     throw std::runtime_error("missing opening parenthesis: " + fmt);
1285   }
1286   name = fmt.substr(0, open_paren);
1287 
1288   bool allow_numbers_as_tensors = should_allow_numbers_as_tensors(name);
1289 
1290   auto last_offset = open_paren + 1;
1291   bool keyword_only = false;
1292   bool done = false;
1293   while (!done) {
1294     auto offset = fmt.find(", ", last_offset);
1295     auto next_offset = offset + 2;
1296     if (offset == std::string::npos) {
1297       offset = fmt.find(')', last_offset);
1298       done = true;
1299       next_offset = offset + 1;
1300       // this 'if' happens for an empty parameter list, i.e. fn().
1301       if (offset == last_offset) {
1302         last_offset = next_offset;
1303         break;
1304       }
1305     }
1306     if (offset == std::string::npos) {
1307       throw std::runtime_error("missing closing parenthesis: " + fmt);
1308     }
1309     if (offset == last_offset) {
1310       throw std::runtime_error("malformed signature: " + fmt);
1311     }
1312 
1313     auto param_str = fmt.substr(last_offset, offset - last_offset);
1314     last_offset = next_offset;
1315     if (param_str == "*") {
1316       keyword_only = true;
1317     } else {
1318       params.emplace_back(param_str, keyword_only);
1319       params.back().allow_numbers_as_tensors = allow_numbers_as_tensors;
1320     }
1321   }
1322 
1323   if (fmt.substr(last_offset) == "|deprecated") {
1324     hidden = true;
1325     // TODO: raise warning when parsing deprecated signatures
1326     deprecated = true;
1327   } else if (fmt.substr(last_offset) == "|hidden") {
1328     hidden = true;
1329   }
1330 
1331   max_args = params.size();
1332 
1333   // count the number of non-optional args
1334   for (auto& param : params) {
1335     if (!param.optional) {
1336       min_args++;
1337     }
1338     if (!param.keyword_only) {
1339       max_pos_args++;
1340     }
1341   }
1342 }
1343 
toString() const1344 std::string FunctionSignature::toString() const {
1345   // optionals, etc.
1346   std::ostringstream ss;
1347   bool keyword_already = false;
1348   ss << "(";
1349   int i = 0;
1350   for (auto& param : params) {
1351     if (i != 0) {
1352       ss << ", ";
1353     }
1354     if (param.keyword_only && !keyword_already) {
1355       ss << "*, ";
1356       keyword_already = true;
1357     }
1358     ss << param.type_name() << " " << param.name;
1359     if (param.optional) {
1360       ss << " = " << param.default_value;
1361     }
1362     i++;
1363   }
1364   ss << ")";
1365   return ss.str();
1366 }
1367 
extra_args(const FunctionSignature & signature,Py_ssize_t nargs)1368 [[noreturn]] static void extra_args(
1369     const FunctionSignature& signature,
1370     Py_ssize_t nargs) {
1371   const auto max_pos_args = signature.max_pos_args;
1372   const auto min_args = signature.min_args;
1373   const long nargs_ = nargs;
1374   if (min_args != max_pos_args) {
1375     throw TypeError(
1376         "%s() takes from %zu to %zu positional arguments but %ld were given",
1377         signature.name.c_str(),
1378         min_args,
1379         max_pos_args,
1380         nargs_);
1381   }
1382   throw TypeError(
1383       "%s() takes %zu positional argument%s but %ld %s given",
1384       signature.name.c_str(),
1385       max_pos_args,
1386       max_pos_args == 1 ? "" : "s",
1387       nargs_,
1388       nargs == 1 ? "was" : "were");
1389 }
1390 
missing_args(const FunctionSignature & signature,int idx)1391 [[noreturn]] static void missing_args(
1392     const FunctionSignature& signature,
1393     int idx) {
1394   int num_missing = 0;
1395   std::stringstream ss;
1396 
1397   auto& params = signature.params;
1398   for (auto it = params.begin() + idx; it != params.end(); ++it) {
1399     if (!it->optional) {
1400       if (num_missing > 0) {
1401         ss << ", ";
1402       }
1403       ss << '"' << it->name << '"';
1404       num_missing++;
1405     }
1406   }
1407 
1408   throw TypeError(
1409       "%s() missing %d required positional argument%s: %s",
1410       signature.name.c_str(),
1411       num_missing,
1412       num_missing == 1 ? "s" : "",
1413       ss.str().c_str());
1414 }
1415 
find_param(FunctionSignature & signature,PyObject * name)1416 static Py_ssize_t find_param(FunctionSignature& signature, PyObject* name) {
1417   Py_ssize_t i = 0;
1418   for (auto& param : signature.params) {
1419     int cmp = PyObject_RichCompareBool(name, param.python_name, Py_EQ);
1420     if (cmp < 0) {
1421       throw python_error();
1422     } else if (cmp) {
1423       return i;
1424     }
1425     i++;
1426   }
1427   return -1;
1428 }
1429 
extra_kwargs(FunctionSignature & signature,PyObject * kwargs,Py_ssize_t num_pos_args)1430 [[noreturn]] static void extra_kwargs(
1431     FunctionSignature& signature,
1432     PyObject* kwargs,
1433     Py_ssize_t num_pos_args) {
1434   PyObject* key = nullptr;
1435   PyObject* value = nullptr;
1436   Py_ssize_t pos = 0;
1437 
1438   while (PyDict_Next(kwargs, &pos, &key, &value)) {
1439     if (!THPUtils_checkString(key)) {
1440       throw TypeError("keywords must be strings");
1441     }
1442 
1443     auto param_idx = find_param(signature, key);
1444     if (param_idx < 0) {
1445       throw TypeError(
1446           "%s() got an unexpected keyword argument '%s'",
1447           signature.name.c_str(),
1448           THPUtils_unpackString(key).c_str());
1449     }
1450 
1451     if (param_idx < num_pos_args) {
1452       throw TypeError(
1453           "%s() got multiple values for argument '%s'",
1454           signature.name.c_str(),
1455           THPUtils_unpackString(key).c_str());
1456     }
1457   }
1458 
1459   // this should never be hit
1460   throw TypeError("invalid keyword arguments");
1461 }
1462 
parse(PyObject * self,PyObject * args,PyObject * kwargs,PyObject * dst[],std::vector<PyObject * > & overloaded_args,bool raise_exception)1463 bool FunctionSignature::parse(
1464     PyObject* self,
1465     PyObject* args,
1466     PyObject* kwargs,
1467     PyObject* dst[], // NOLINT
1468     std::vector<PyObject*>& overloaded_args,
1469     bool raise_exception) {
1470   Py_ssize_t nargs = args ? PyTuple_GET_SIZE(args) : 0;
1471   auto remaining_kwargs = kwargs ? PyDict_Size(kwargs) : 0;
1472   size_t arg_pos = 0;
1473   bool allow_varargs_intlist = false;
1474 
1475   // if there is a single positional IntArrayRef argument, i.e. expand(..),
1476   // view(...), allow a var-args style IntArrayRef, so expand(5,3) behaves as
1477   // expand((5,3))
1478   if (max_pos_args == 1 &&
1479       (params[0].type_ == ParameterType::INT_LIST ||
1480        params[0].type_ == ParameterType::SYM_INT_LIST)) {
1481     allow_varargs_intlist = true;
1482   }
1483 
1484   if (static_cast<size_t>(nargs) > max_pos_args && !allow_varargs_intlist) {
1485     if (raise_exception) {
1486       // foo() takes takes 2 positional arguments but 3 were given
1487       extra_args(*this, nargs);
1488     }
1489     return false;
1490   }
1491 
1492   int i = 0;
1493   if (self != nullptr && check_has_torch_function(self, /*ignore_mode*/ true)) {
1494     append_overloaded_tensor(&overloaded_args, self);
1495   }
1496   for (auto& param : params) {
1497     PyObject* obj = nullptr;
1498     bool is_kwd = false;
1499     if (arg_pos < static_cast<size_t>(nargs)) {
1500       // extra positional args given after single positional IntArrayRef arg
1501       if (param.keyword_only) {
1502         if (raise_exception) {
1503           extra_args(*this, nargs);
1504         }
1505         return false;
1506       }
1507       obj = PyTuple_GET_ITEM(args, arg_pos);
1508     } else if (kwargs) {
1509       obj = PyDict_GetItem(kwargs, param.python_name);
1510       for (PyObject* numpy_name : param.numpy_python_names) {
1511         if (obj) {
1512           break;
1513         }
1514         obj = PyDict_GetItem(kwargs, numpy_name);
1515       }
1516       is_kwd = true;
1517     }
1518 
1519     int64_t failed_idx = -1;
1520     bool varargs_eligible = allow_varargs_intlist && arg_pos == 0 && !is_kwd;
1521     if ((!obj && param.optional) || (obj == Py_None && param.allow_none)) {
1522       dst[i++] = nullptr;
1523     } else if (!obj) {
1524       if (raise_exception) {
1525         // foo() missing 1 required positional argument: "b"
1526         missing_args(*this, i);
1527       }
1528       return false;
1529     } else if (param.check(obj, overloaded_args, i, &failed_idx)) {
1530       dst[i++] = obj;
1531       // XXX: the Variable check is necessary because sizes become tensors when
1532       // tracer is enabled. This behavior easily leads to ambiguities, and we
1533       // should avoid having complex signatures that make use of it...
1534     } else if (
1535         varargs_eligible &&
1536         (is_int_or_symint_list(args, param.size, &failed_idx))) {
1537       // take all positional arguments as this parameter
1538       // e.g. permute(1, 2, 3) -> permute((1, 2, 3))
1539       dst[i++] = args;
1540       arg_pos = nargs;
1541       continue;
1542     } else if (raise_exception) {
1543       if (is_kwd) {
1544         // foo(): argument 'other' must be str, not int
1545         throw TypeError(
1546             "%s(): argument '%s' must be %s, not %s",
1547             name.c_str(),
1548             param.name.c_str(),
1549             param.type_name().c_str(),
1550             Py_TYPE(obj)->tp_name);
1551       } else {
1552         // foo(): argument 'other' (position 2) must be str, not int
1553         if (failed_idx != -1) {
1554           if (!(PyTuple_Check(obj) || PyList_Check(obj))) {
1555             TORCH_INTERNAL_ASSERT(varargs_eligible);
1556             obj = args;
1557           }
1558           TORCH_INTERNAL_ASSERT(failed_idx < PySequence_Size(obj));
1559           throw TypeError(
1560               "%s(): argument '%s' (position %ld) must be %s, but found element of type %s at pos %ld",
1561               name.c_str(),
1562               param.name.c_str(),
1563               static_cast<long>(arg_pos + 1),
1564               param.type_name().c_str(),
1565               Py_TYPE(py::reinterpret_steal<py::object>(
1566                           PySequence_GetItem(obj, failed_idx))
1567                           .ptr())
1568                   ->tp_name,
1569               static_cast<long>(failed_idx));
1570         }
1571         throw TypeError(
1572             "%s(): argument '%s' (position %ld) must be %s, not %s",
1573             name.c_str(),
1574             param.name.c_str(),
1575             static_cast<long>(arg_pos + 1),
1576             param.type_name().c_str(),
1577             Py_TYPE(obj)->tp_name);
1578       }
1579     } else {
1580       return false;
1581     }
1582 
1583     if (!is_kwd) {
1584       arg_pos++;
1585     } else if (obj) {
1586       remaining_kwargs--;
1587     }
1588   }
1589 
1590   if (remaining_kwargs > 0) {
1591     if (raise_exception) {
1592       // foo() got an unexpected keyword argument "b"
1593       extra_kwargs(*this, kwargs, nargs);
1594     }
1595     return false;
1596   }
1597   return true;
1598 }
1599 
PythonArgParser(const std::vector<std::string> & fmts,bool traceable)1600 PythonArgParser::PythonArgParser(
1601     const std::vector<std::string>& fmts,
1602     bool traceable)
1603     : max_args(0), traceable(traceable) {
1604   int index = 0;
1605   for (auto& fmt : fmts) {
1606     signatures_.emplace_back(fmt, index);
1607     ++index;
1608   }
1609   for (auto& signature : signatures_) {
1610     if (signature.max_args > max_args) {
1611       max_args = signature.max_args;
1612     }
1613   }
1614   if (!signatures_.empty()) {
1615     function_name = signatures_[0].name;
1616   }
1617 
1618   // Check deprecated signatures last
1619   std::stable_partition(
1620       signatures_.begin(), signatures_.end(), [](const FunctionSignature& sig) {
1621         return !sig.deprecated;
1622       });
1623 }
1624 
check_deprecated(const FunctionSignature & signature)1625 void PythonArgParser::check_deprecated(const FunctionSignature& signature) {
1626   if (signature.deprecated) {
1627     auto msg = c10::str(
1628         "This overload of ",
1629         signature.name,
1630         " is deprecated:\n\t",
1631         signature.name,
1632         signature.toString());
1633     auto signatures = get_signatures();
1634     if (!signatures.empty()) {
1635       msg += "\nConsider using one of the following signatures instead:";
1636       for (const auto& sig : signatures) {
1637         msg += "\n\t";
1638         msg += signature.name;
1639         msg += sig;
1640       }
1641     }
1642     TORCH_WARN_ONCE(msg);
1643   }
1644 }
1645 
raw_parse(PyObject * self,PyObject * args,PyObject * kwargs,PyObject * parsed_args[])1646 PythonArgs PythonArgParser::raw_parse(
1647     PyObject* self,
1648     PyObject* args,
1649     PyObject* kwargs,
1650     PyObject* parsed_args[]) { // NOLINT
1651   if (signatures_.size() == 1) {
1652     auto& signature = signatures_[0];
1653     std::vector<PyObject*> overloaded_args;
1654     signature.parse(self, args, kwargs, parsed_args, overloaded_args, true);
1655     check_deprecated(signature);
1656     return PythonArgs(
1657         traceable, signature, parsed_args, std::move(overloaded_args));
1658   }
1659 
1660   for (auto& signature : signatures_) {
1661     std::vector<PyObject*> overloaded_args;
1662     if (signature.parse(
1663             self, args, kwargs, parsed_args, overloaded_args, false)) {
1664       check_deprecated(signature);
1665       return PythonArgs(
1666           traceable, signature, parsed_args, std::move(overloaded_args));
1667     }
1668   }
1669 
1670   print_error(self, args, kwargs, parsed_args);
1671 }
1672 
print_error(PyObject * self,PyObject * args,PyObject * kwargs,PyObject * parsed_args[])1673 void PythonArgParser::print_error(
1674     PyObject* self,
1675     PyObject* args,
1676     PyObject* kwargs,
1677     PyObject* parsed_args[]) { // NOLINT
1678   size_t num_args =
1679       (args ? PyTuple_GET_SIZE(args) : 0) + (kwargs ? PyDict_Size(kwargs) : 0);
1680   std::vector<unsigned> plausible_idxs;
1681   unsigned i = 0;
1682   for (auto& signature : signatures_) {
1683     if (num_args >= signature.min_args && num_args <= signature.max_args &&
1684         !signature.hidden) {
1685       plausible_idxs.push_back(i);
1686     }
1687     i++;
1688   }
1689 
1690   if (plausible_idxs.size() == 1) {
1691     auto& signature = signatures_[plausible_idxs[0]];
1692     std::vector<PyObject*> overloaded_args;
1693     signature.parse(self, args, kwargs, parsed_args, overloaded_args, true);
1694   }
1695 
1696   auto options = get_signatures();
1697   auto msg =
1698       torch::format_invalid_args(args, kwargs, function_name + "()", options);
1699   throw TypeError("%s", msg.c_str());
1700 }
1701 
get_signatures() const1702 std::vector<std::string> PythonArgParser::get_signatures() const {
1703   std::vector<std::string> options;
1704   for (auto& signature : signatures_) {
1705     if (!signature.hidden) {
1706       options.push_back(signature.toString());
1707     }
1708   }
1709   return options;
1710 }
1711 
tensor_slow(int i)1712 at::Tensor PythonArgs::tensor_slow(int i) {
1713   PyObject* obj = args[i];
1714   if (!obj) {
1715     return at::Tensor();
1716   }
1717   if (THPVariable_Check(obj)) {
1718     return THPVariable_Unpack(obj);
1719   }
1720 
1721   bool save_symint = false;
1722   at::Scalar scalar;
1723   if (PyBool_Check(obj)) {
1724     scalar = at::Scalar(THPUtils_unpackBool(obj));
1725   } else if (THPUtils_checkLong(obj)) {
1726     int overflow = -1;
1727     long long value = PyLong_AsLongLongAndOverflow(obj, &overflow);
1728     if (value == -1 && PyErr_Occurred()) {
1729       throw python_error();
1730     }
1731     if (overflow != 0) {
1732       // try unsigned
1733       unsigned long long value = PyLong_AsUnsignedLongLong(obj);
1734       if (value == static_cast<unsigned long long>(-1) && PyErr_Occurred()) {
1735         throw python_error();
1736       }
1737       scalar = at::Scalar(static_cast<uint64_t>(value));
1738     } else {
1739       scalar = at::Scalar(static_cast<int64_t>(value));
1740     }
1741   } else if (PyComplex_Check(obj)) {
1742     scalar = at::Scalar(THPUtils_unpackComplexDouble(obj));
1743   } else if (THPUtils_checkDouble(obj)) {
1744     scalar = at::Scalar(THPUtils_unpackDouble(obj));
1745     // NB: we DO NOT put symbolic ints/floats into the Scalar itself,
1746     // because although Scalar supports SymInt/SymFloat, the subsequent
1747     // conversion to Tensor does not.  Instead, do it out of band.
1748   } else if (torch::is_symint(py::handle(obj))) {
1749     save_symint = true;
1750     // This scalar value doesn't matter, it shouldn't ever actually
1751     // get read out.  Make it a big and weird looking number to help
1752     // people figure out if there's aproblem.
1753     scalar = at::Scalar(7777777);
1754   } else if (torch::is_symfloat(py::handle(obj))) {
1755     save_symint = true;
1756     scalar = at::Scalar(std::numeric_limits<double>::quiet_NaN());
1757   } else if (torch::is_symbool(py::handle(obj))) {
1758     save_symint = true;
1759     scalar = at::Scalar(true);
1760   } else {
1761     // NB: Are you here because you passed None to a Variable method,
1762     // and you expected an undefined tensor to be returned?   Don't add
1763     // a test for Py_None here; instead, you need to mark the argument
1764     // as *allowing none*; you can do this by writing 'Tensor?' instead
1765     // of 'Tensor' in the ATen metadata.
1766     throw TypeError(
1767         "expected Tensor as argument %d, but got %s", i, Py_TYPE(obj)->tp_name);
1768   }
1769   at::AutoDispatchBelowADInplaceOrView guard; // TODO: remove
1770   at::tracer::impl::NoTracerDispatchMode tracer_guard;
1771 
1772   at::Tensor tensor = scalar_to_tensor(scalar);
1773   tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
1774 
1775   if (save_symint) {
1776     auto py_tensor = py::cast(tensor);
1777     if (PyObject_SetAttrString(py_tensor.ptr(), "_wrapped_number", obj) < 0) {
1778       throw python_error();
1779     }
1780   }
1781 
1782   return tensor;
1783 }
1784 
scalar_slow(int i)1785 at::Scalar PythonArgs::scalar_slow(int i) {
1786   if (traceable && jit::tracer::isTracing() && THPVariable_Check(args[i])) {
1787     auto& var = THPVariable_Unpack(args[i]);
1788     jit::tracer::ArgumentStash::stashValue(
1789         signature.params[i].name, idx, var, c10::NumberType::get());
1790   }
1791 
1792   return scalar_slow(args[i]);
1793 }
1794 
scalar_slow(PyObject * arg)1795 at::Scalar PythonArgs::scalar_slow(PyObject* arg) {
1796   // Zero-dim tensors are converted to Scalars as-is. Note this doesn't
1797   // currently handle most NumPy scalar types except np.float64.
1798   if (THPVariable_Check(arg)) {
1799     return THPVariable_Unpack(arg).item();
1800   }
1801 
1802   if (THPUtils_checkLong(arg)) {
1803     int overflow = -1;
1804     long long value = PyLong_AsLongLongAndOverflow(arg, &overflow);
1805     if (value == -1 && PyErr_Occurred()) {
1806       throw python_error();
1807     }
1808     if (overflow != 0) {
1809       // try unsigned
1810       unsigned long long value = PyLong_AsUnsignedLongLong(arg);
1811       if (value == static_cast<unsigned long long>(-1) && PyErr_Occurred()) {
1812         throw python_error();
1813       }
1814       return at::Scalar(static_cast<uint64_t>(value));
1815     } else {
1816       return at::Scalar(static_cast<int64_t>(value));
1817     }
1818   }
1819 
1820   if (PyBool_Check(arg)) {
1821     return at::Scalar(THPUtils_unpackBool(arg));
1822   }
1823 
1824   if (PyComplex_Check(arg)) {
1825     return at::Scalar(THPUtils_unpackComplexDouble(arg));
1826   }
1827 
1828   if (torch::is_symint(arg)) {
1829     return at::Scalar(py::cast<c10::SymInt>(arg));
1830   }
1831 
1832   if (torch::is_symfloat(arg)) {
1833     return at::Scalar(py::cast<c10::SymFloat>(arg));
1834   }
1835 
1836   if (torch::is_symbool(arg)) {
1837     // Windows build fails with C2440: '<function-style-cast>'
1838     // when at:Scalar(py::cast<c10::SymBool>(arg))
1839     auto sym_bool = py::handle(arg).cast<c10::SymBool>();
1840     return at::Scalar(sym_bool);
1841   }
1842 
1843   return at::Scalar(THPUtils_unpackDouble(arg));
1844 }
1845 
1846 } // namespace torch
1847