1 #pragma once
2
3 // Parse arguments to Python functions implemented in C++
4 // This is similar to PyArg_ParseTupleAndKeywords(), but specifically handles
5 // the types relevant to PyTorch and distinguishes between overloaded function
6 // signatures.
7 //
8 // Example:
9 //
10 // static PythonArgParser parser({
11 // "norm(Scalar p, int64_t dim, bool keepdim=False)",
12 // "norm(Scalar p=2)",
13 // });
14 // ParsedArgs<3> parsed_args;
15 // auto r = parser.parse(args, kwargs, parsed_args);
16 // if (r.idx == 0) {
17 // norm(r.scalar(0), r.int64(1), r.bool(0));
18 // } else {
19 // norm(r.scalar(0));
20 // }
21 //
22 // We auto-generate most uses of PythonArgParser; the generated files
23 // are torch/csrc/autograd/generated/python_*.cpp
24 //
25 // Some gotchas that you should watch out for:
26 //
27 // - Note [Order of overloads matters]
28 // Order of overloads matters. A set of input arguments may
29 // bind to multiple argument specs; we will always pick the
30 // first one in PythonArgParser. However, when you are writing
31 // overloads in, e.g., native_functions.yaml, you don't have to
32 // worry about what order you write them, because the code
33 // generation logic always gives the overloads a canonical
34 // order, where Tensor overloads come first, before Scalar overloads.
35 // This logic is in sort_declarations in
36 // tools/autograd/gen_python_functions.py
37 //
38 // - Zero-dim tensors (e.g., torch.tensor(2)) bind to both
39 // Scalar and Tensor, UNLESS they require grad (in which case
40 // they only bind to Tensor).
41
42 #include <pybind11/pytypes.h>
43 #include <torch/csrc/python_headers.h>
44
45 #include <torch/csrc/Device.h>
46 #include <torch/csrc/Dtype.h>
47 #include <torch/csrc/DynamicTypes.h>
48 #include <torch/csrc/Exceptions.h>
49 #include <torch/csrc/Export.h>
50 #include <torch/csrc/Generator.h>
51 #include <torch/csrc/Layout.h>
52 #include <torch/csrc/MemoryFormat.h>
53 #include <torch/csrc/QScheme.h>
54 #include <torch/csrc/Stream.h>
55 #include <torch/csrc/autograd/python_variable.h>
56 #include <torch/csrc/autograd/variable.h>
57 #include <torch/csrc/dynamo/eval_frame.h>
58 #include <torch/csrc/jit/frontend/tracer.h>
59 #include <torch/csrc/python_dimname.h>
60 #include <torch/csrc/tensor/python_tensor.h>
61 #include <torch/csrc/utils/disable_torch_function.h>
62 #include <torch/csrc/utils/object_ptr.h>
63 #include <torch/csrc/utils/pybind.h>
64 #include <torch/csrc/utils/python_numbers.h>
65 #include <torch/csrc/utils/python_strings.h>
66 #include <torch/csrc/utils/python_symnode.h>
67 #include <torch/csrc/utils/six.h>
68
69 #include <ATen/DeviceAccelerator.h>
70 #include <ATen/PythonTorchFunctionTLS.h>
71 #include <ATen/core/Tensor.h>
72 #include <c10/util/Exception.h>
73 #include <c10/util/irange.h>
74
75 #include <c10/core/SymFloat.h>
76 #include <c10/core/SymNodeImpl.h>
77
78 #include <c10/core/DispatchKeySet.h>
79 #include <array>
80 #include <cstddef>
81 #include <string>
82 #include <vector>
83
// Returns true if `obj` can bind to a Scalar parameter: a Python
// float/int/complex, a torch symbolic int/float/bool wrapper, or (when
// built with numpy) a numpy scalar.
// NOTE: PyBool is a subtype of PyLong in CPython, so bools also pass
// the PyLong_Check below.
inline bool THPUtils_checkScalar(PyObject* obj) {
#ifdef USE_NUMPY
  if (torch::utils::is_numpy_scalar(obj)) {
    return true;
  }
#endif
  return PyFloat_Check(obj) || PyLong_Check(obj) || PyComplex_Check(obj) ||
      torch::is_symint(py::handle(obj)) ||
      torch::is_symfloat(py::handle(obj)) || torch::is_symbool(py::handle(obj));
}
94
95 namespace torch {
96
97 bool should_allow_numbers_as_tensors(const std::string& name);
98
// The kind of a single formal parameter in a FunctionSignature. Drives both
// argument type-checking (FunctionParameter::check) and unpacking (the
// PythonArgs accessors below).
enum class ParameterType {
  TENSOR,
  SCALAR,
  INT64,
  SYM_INT,
  DOUBLE,
  COMPLEX,
  TENSOR_LIST,
  INT_LIST,
  GENERATOR,
  BOOL,
  STORAGE,
  PYOBJECT,
  SCALARTYPE,
  LAYOUT,
  MEMORY_FORMAT,
  DEVICE,
  STREAM,
  STRING,
  DIMNAME,
  DIMNAME_LIST,
  QSCHEME,
  FLOAT_LIST,
  SCALAR_LIST,
  SYM_INT_LIST,
  DISPATCH_KEY_SET
};
126
127 struct FunctionParameter;
128 struct FunctionSignature;
129 struct PythonArgs;
130
// Contains bound Python arguments in declaration order.
// N must be >= the parser's max_args; PythonArgParser::parse() enforces
// this at runtime. A nullptr slot means the argument was not supplied
// (the accessors then fall back to the parameter's default).
template <int N>
struct ParsedArgs {
  ParsedArgs() : args() {} // zero-initialize all slots to nullptr
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
  PyObject* args[N]; // presumably borrowed references — no decref here
};
138
// A PythonArgParser contains a list of valid signatures. Instances are
// typically global variables and should be immutable.
struct PYBIND11_EXPORT PythonArgParser {
  // `fmts` are signature strings (see the example at the top of this file);
  // `traceable` enables extra stashing of int arguments under the JIT tracer.
  explicit PythonArgParser(
      const std::vector<std::string>& fmts,
      bool traceable = false);

  // meant only for `torch` functions.
  template <int N>
  inline PythonArgs parse(
      PyObject* self,
      PyObject* args,
      PyObject* kwargs,
      ParsedArgs<N>& dst);

  // Same as above without a `self` object (free functions).
  template <int N>
  inline PythonArgs parse(PyObject* args, PyObject* kwargs, ParsedArgs<N>& dst);

  // For signatures taking no arguments beyond `self`.
  inline PythonArgs parse(PyObject* self, ParsedArgs<0>& dst);

  // Formatted strings of non-hidden signatures
  std::vector<std::string> get_signatures() const;

 private:
  // Raises a TypeError describing why no signature matched; never returns.
  [[noreturn]] void print_error(
      PyObject* self,
      PyObject* args,
      PyObject* kwargs,
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
      PyObject* parsed_args[]);
  // Emits a deprecation warning if `signature` is marked deprecated.
  void check_deprecated(const FunctionSignature& signature);
  // Tries each signature in order; fills `parsed_args` from the first match.
  PythonArgs raw_parse(
      PyObject* self,
      PyObject* args,
      PyObject* kwargs,
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
      PyObject* parsed_args[]);

  std::vector<FunctionSignature> signatures_; // overloads, in canonical order
  std::string function_name;
  size_t max_args; // max over all signatures; lower bound for ParsedArgs<N>
  bool traceable;
};
182
// FunctionSignature represents a single valid signature for a Python function.
// It is immutable once constructed. The contained data can be concurrently
// accessed by multiple calls.
struct FunctionSignature {
  explicit FunctionSignature(const std::string& fmt, int index);

  // Attempts to bind `args`/`kwargs` to this signature, writing the bound
  // objects into `dst`. Returns false (or throws, if `raise_exception`)
  // on mismatch; collects __torch_function__ overloads into
  // `overloaded_args`.
  bool parse(
      PyObject* self,
      PyObject* args,
      PyObject* kwargs,
      // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
      PyObject* dst[],
      std::vector<PyObject*>& overloaded_args,
      bool raise_exception);

  // Human-readable form of the signature, used in error messages.
  std::string toString() const;

  std::string name;
  std::vector<FunctionParameter> params;
  // NOTE(review): min/max semantics are established in the .cpp, not here —
  // presumably min required args, total args, and positional-arg limit.
  size_t min_args;
  size_t max_args;
  size_t max_pos_args;
  int index; // position among the parser's overloads (PythonArgs::idx)
  bool hidden; // excluded from get_signatures()
  bool deprecated; // triggers check_deprecated warning when matched
};
209
// PythonArgs contains bound Python arguments for an actual invocation
// along with references to the matched signature.
// Accessors take the parameter index `i` (declaration order); when
// args[i] is nullptr the parameter's default (or nullopt/empty) is
// returned instead.
struct PythonArgs {
  PythonArgs(
      bool traceable,
      const FunctionSignature& signature,
      PyObject** args,
      std::vector<PyObject*> overloaded_args)
      : idx(signature.index),
        traceable(traceable),
        signature(signature),
        args(args),
        overloaded_args(std::move(overloaded_args)) {}

  int idx; // which overload matched (see example at top of file)
  bool traceable;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const FunctionSignature& signature;
  PyObject** args; // bound argument slots; nullptr == not supplied
  std::vector<PyObject*> overloaded_args; // NOTE: borrowed references

  // True when __torch_function__ dispatch should run (an overloaded
  // argument was seen or a torch-function mode is active).
  inline bool has_torch_function();
  inline std::string get_func_name();
  // --- per-parameter unpackers -------------------------------------------
  inline at::Tensor tensor(int i);
  inline std::optional<at::Tensor> optionalTensor(int i);
  inline at::Scalar scalar(int i);
  inline at::Scalar scalarWithDefault(int i, const at::Scalar& default_scalar);
  inline std::vector<at::Scalar> scalarlist(int i);
  inline std::vector<at::Tensor> tensorlist(int i);
  inline torch::List<std::optional<at::Tensor>> list_of_optional_tensors(int i);
  template <int N>
  inline std::array<at::Tensor, N> tensorlist_n(int i);
  inline std::vector<int64_t> intlist(int i);
  inline std::vector<c10::SymInt> symintlist(int i);
  inline c10::OptionalArray<int64_t> intlistOptional(int i);
  inline c10::OptionalArray<c10::SymInt> symintlistOptional(int i);
  inline std::vector<int64_t> intlistWithDefault(
      int i,
      std::vector<int64_t> default_intlist);
  inline std::optional<at::Generator> generator(int i);
  inline at::Storage storage(int i);
  inline at::Storage storage(
      int i,
      at::ScalarType& storage_scalar_type,
      bool& is_typed_storage);
  inline c10::Stream stream(int i);
  inline at::ScalarType scalartype(int i);
  inline at::ScalarType scalartypeWithDefault(
      int i,
      at::ScalarType default_scalartype);
  inline std::optional<at::ScalarType> scalartypeOptional(int i);
  inline std::optional<at::Scalar> scalarOptional(int i);
  inline std::optional<int64_t> toInt64Optional(int i);
  inline std::optional<c10::SymInt> toSymIntOptional(int i);
  inline std::optional<bool> toBoolOptional(int i);
  inline std::optional<double> toDoubleOptional(int i);
  inline c10::OptionalArray<double> doublelistOptional(int i);
  inline std::vector<double> doublelist(int i);
  inline std::vector<double> getDoublelist(int i);
  inline at::Layout layout(int i);
  inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
  inline std::optional<at::Layout> layoutOptional(int i);
  inline at::Device device(int i);
  inline at::Device deviceWithDefault(int i, const at::Device& default_device);
  inline std::optional<at::Device> deviceOptional(int i);
  inline at::Dimname dimname(int i);
  inline std::vector<at::Dimname> dimnamelist(int i);
  inline std::optional<std::vector<at::Dimname>> toDimnameListOptional(int i);
  inline at::MemoryFormat memoryformat(int i);
  inline std::optional<at::MemoryFormat> memoryformatOptional(int i);
  inline at::QScheme toQScheme(int i);
  inline std::string string(int i);
  inline std::string stringWithDefault(int i, const std::string& default_str);
  inline std::optional<std::string> stringOptional(int i);
  inline c10::string_view stringView(int i);
  inline c10::string_view stringViewWithDefault(
      int i,
      const c10::string_view default_str);
  inline std::optional<c10::string_view> stringViewOptional(int i);
  inline PyObject* pyobject(int i);
  inline int64_t toInt64(int i);
  inline c10::SymInt toSymInt(int i);
  inline c10::SymBool toSymBool(int i);
  inline int64_t toInt64WithDefault(int i, int64_t default_int);
  inline double toDouble(int i);
  inline double toDoubleWithDefault(int i, double default_double);
  inline c10::complex<double> toComplex(int i);
  inline c10::complex<double> toComplexWithDefault(
      int i,
      c10::complex<double> default_complex);
  inline bool toBool(int i);
  inline bool toBoolWithDefault(int i, bool default_bool);
  inline bool isNone(int i);
  inline std::optional<c10::DispatchKeySet> toDispatchKeySetOptional(int i);

 private:
  // Slow paths defined out of line; handle non-exact types, numbers-as-
  // tensors, etc.
  at::Tensor tensor_slow(int i);
  at::Scalar scalar_slow(int i);
  at::Scalar scalar_slow(PyObject* arg);
};
310
// FunctionParameter is a single formal parameter of a Python function.
// It is immutable once constructed.
struct FunctionParameter {
  FunctionParameter(const std::string& fmt, bool keyword_only);

  // Returns true if `obj` is acceptable for this parameter; may append
  // __torch_function__ overloads to `overloaded_args`. `failed_idx`, when
  // non-null, receives the offending element index for list parameters.
  bool check(
      PyObject* obj,
      std::vector<PyObject*>& overloaded_args,
      int argnum,
      int64_t* failed_idx = nullptr);

  // Parses the textual default value from the signature string.
  void set_default_str(const std::string& str);
  // Human-readable type name for error messages.
  std::string type_name() const;

  ParameterType type_;
  bool optional;
  bool allow_none;
  bool keyword_only;
  bool allow_numbers_as_tensors = false;
  int size; // fixed list length, or 0 if unconstrained (see intlist/dimnamelist)
  std::string name;
  // having this as a raw PyObject * will presumably leak it, but these are only
  // held by static objects anyway, and Py_Finalize can already be called when
  // this is destructed.
  PyObject* python_name;
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
  at::SmallVector<PyObject*, 5> numpy_python_names;
  // Default values; which member is meaningful depends on type_.
  at::Scalar default_scalar;
  std::vector<int64_t> default_intlist;
  std::string default_string;
  union {
    bool default_bool;
    int64_t default_int;
    double default_double;
    // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
    double default_complex[2]; // see Scalar
    at::ScalarType default_scalartype;
    at::Layout default_layout;
  };
  std::string default_value; // raw default text from the signature string
};
352
// Parses `args`/`kwargs` (plus optional `self`) against the registered
// signatures, writing the bound objects into `dst.args`. Throws a Python
// ValueError if the compile-time buffer is too small for this parser.
template <int N>
inline PythonArgs PythonArgParser::parse(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs,
    ParsedArgs<N>& dst) {
  TORCH_CHECK_VALUE(
      N >= max_args,
      "PythonArgParser: dst ParsedArgs buffer does not have enough capacity, expected ",
      max_args,
      " (got ",
      N,
      ")");
  return raw_parse(self, args, kwargs, dst.args);
}
368
// Convenience overload for free functions (no `self` object).
template <int N>
inline PythonArgs PythonArgParser::parse(
    PyObject* args,
    PyObject* kwargs,
    ParsedArgs<N>& dst) {
  return parse(nullptr, args, kwargs, dst);
}
376
// Convenience overload for zero-argument signatures.
inline PythonArgs PythonArgParser::parse(PyObject* self, ParsedArgs<0>& dst) {
  return parse(self, nullptr, nullptr, dst);
}
380
has_torch_function()381 inline bool PythonArgs::has_torch_function() {
382 return !overloaded_args.empty() || at::impl::torch_function_mode_enabled();
383 }
384
// Name of the matched signature's function (for error/dispatch messages).
inline std::string PythonArgs::get_func_name() {
  return signature.name;
}
388
// TODO: this can return MaybeOwned
// Unpacks argument i as a Tensor. Fast path for objects that are exactly
// THPVariable; everything else (subclasses, numbers-as-tensors, None, ...)
// goes through tensor_slow().
inline at::Tensor PythonArgs::tensor(int i) {
  if (args[i] && THPVariable_CheckExact(args[i])) {
    return THPVariable_Unpack(args[i]);
  }
  return tensor_slow(i);
}
396
optionalTensor(int i)397 inline std::optional<at::Tensor> PythonArgs::optionalTensor(int i) {
398 at::Tensor t = tensor(i);
399 // NOLINTNEXTLINE(bugprone-branch-clone)
400 if (t.defined()) {
401 return t;
402 } else {
403 return std::nullopt;
404 }
405 }
406
scalar(int i)407 inline at::Scalar PythonArgs::scalar(int i) {
408 if (!args[i])
409 return signature.params[i].default_scalar;
410 return scalar_slow(i);
411 }
412
// Unpacks argument i (a Python tuple or list) as a vector of Scalars.
// Returns an empty vector when the argument is absent.
inline std::vector<at::Scalar> PythonArgs::scalarlist(int i) {
  if (!args[i])
    return std::vector<at::Scalar>();
  auto tuple = six::isTuple(args[i]);
  THPObjectPtr arg = six::maybeAsTuple(args[i]);
  // NOLINTNEXTLINE(bugprone-branch-clone)
  auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get());
  std::vector<at::Scalar> res(size);
  for (const auto idx : c10::irange(size)) {
    // GET_ITEM returns a borrowed reference; `arg` keeps the container alive.
    PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx)
                          : PyList_GET_ITEM(arg.get(), idx);
    res[idx] = scalar_slow(obj);
  }
  return res;
}
428
// Unpacks argument i as a Scalar, substituting `default_scalar` when the
// argument was not supplied by the caller.
inline at::Scalar PythonArgs::scalarWithDefault(
    int i,
    const at::Scalar& default_scalar) {
  return args[i] != nullptr ? scalar_slow(i) : default_scalar;
}
436
scalarOptional(int i)437 inline std::optional<at::Scalar> PythonArgs::scalarOptional(int i) {
438 if (!args[i])
439 return std::nullopt;
440 return scalar_slow(i);
441 }
442
// Unpacks argument i (a Python tuple or list) as a vector of Tensors.
// Returns an empty vector when the argument is absent.
inline std::vector<at::Tensor> PythonArgs::tensorlist(int i) {
  if (!args[i])
    return std::vector<at::Tensor>();
  auto tuple = six::isTuple(args[i]);
  THPObjectPtr arg = six::maybeAsTuple(args[i]);
  // NOLINTNEXTLINE(bugprone-branch-clone)
  auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get());
  std::vector<at::Tensor> res(size);
  for (const auto idx : c10::irange(size)) {
    PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx)
                          : PyList_GET_ITEM(arg.get(), idx);
    // This is checked by the argument parser so it's safe to cast without
    // checking if this is a tensor first
    res[idx] = THPVariable_Unpack(obj);
  }
  return res;
}
460
// Unpacks argument i (a Python tuple or list) as a torch::List of optional
// Tensors. Returns an empty list when the argument is absent.
inline torch::List<std::optional<at::Tensor>> PythonArgs::
    list_of_optional_tensors(int i) {
  if (!args[i])
    return torch::List<std::optional<at::Tensor>>();
  auto tuple = six::isTuple(args[i]);
  THPObjectPtr arg = six::maybeAsTuple(args[i]);
  // NOLINTNEXTLINE(bugprone-branch-clone)
  auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get());
  torch::List<std::optional<at::Tensor>> res;
  res.reserve(size); // single allocation up front
  for (const auto idx : c10::irange(size)) {
    PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx)
                          : PyList_GET_ITEM(arg.get(), idx);
    // This is checked by the argument parser so it's safe to cast without
    // checking if this is a tensor first
    res.push_back(THPVariable_Unpack(obj));
  }
  return res;
}
480
// Unpacks argument i as exactly N Tensors into a std::array. Returns an
// array of undefined Tensors when the argument is absent; throws a
// TypeError if the sequence length differs from N.
template <int N>
inline std::array<at::Tensor, N> PythonArgs::tensorlist_n(int i) {
  auto res = std::array<at::Tensor, N>();
  if (!args[i])
    return res;
  auto tuple = six::isTuple(args[i]);
  THPObjectPtr arg = six::maybeAsTuple(args[i]);
  // NOLINTNEXTLINE(bugprone-branch-clone)
  auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get());
  if (size != N) {
    throw TypeError("expected tuple of %d elements but got %d", N, (int)size);
  }
  for (const auto idx : c10::irange(size)) {
    PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx)
                          : PyList_GET_ITEM(arg.get(), idx);
    // This is checked by the argument parser so it's safe to cast without
    // checking if this is a tensor first
    res[idx] = THPVariable_Unpack(obj);
  }
  return res;
}
502
// Unpacks argument i as a vector of int64_t, using the signature's default
// list when the argument is absent.
inline std::vector<int64_t> PythonArgs::intlist(int i) {
  return intlistWithDefault(i, signature.params[i].default_intlist);
}
506
// Converts a SymInt to a new-reference PyObject*: a wrapped symbolic value
// when symbolic, otherwise a plain Python int.
inline PyObject* toPyObject(const c10::SymInt& symint) {
  if (symint.is_symbolic()) {
    auto r = py::cast(symint).release().ptr();
    TORCH_INTERNAL_ASSERT(r);
    return r;
  } else {
    // Non-symbolic, so maybe_as_int() is expected to hold a value here.
    auto m = symint.maybe_as_int();
    return THPUtils_packInt64(*m);
  }
}
517
518 inline void throw_intlist_exception(
519 const torch::PythonArgs* args,
520 size_t i,
521 PyObject* obj,
522 size_t idx,
523 const std::exception& e = python_error()) {
524 std::string error = strlen(e.what())
525 ? e.what()
526 : std::string("type must be ") + args->signature.params[i].type_name() +
527 ",but got " + Py_TYPE(obj)->tp_name;
528 throw TypeError(
529 "%s(): argument '%s' failed to unpack the object at pos %zu with error \"%s\"",
530 args->signature.name.c_str(),
531 args->signature.params[i].name.c_str(),
532 idx + 1,
533 error.c_str());
534 }
535
// Unpacks argument i as a vector of SymInt.
// Accepts: absent (-> signature default, widened to SymInt), a single
// int/SymInt when the parameter declares a fixed size (broadcast to that
// size), or a tuple/list whose elements are ints, SymInts, or 1-element
// integral Tensors.
inline std::vector<c10::SymInt> PythonArgs::symintlist(int i) {
  if (!args[i]) {
    return c10::fmap(signature.params[i].default_intlist, [](int64_t di) {
      return c10::SymInt(di);
    });
  }

  // A bare int/SymInt is allowed for fixed-size list parameters (e.g.
  // "IntArrayRef[2]"): replicate it `size1` times.
  const auto size1 = signature.params[i].size;
  if (size1 > 0 && THPUtils_checkLong(args[i])) {
    return std::vector<c10::SymInt>(
        size1, c10::SymInt(THPUtils_unpackLong(args[i])));
  }

  if (size1 > 0 && torch::is_symint(py::handle(args[i]))) {
    auto si = py::handle(args[i]).cast<c10::SymInt>();
    return std::vector<c10::SymInt>(size1, si);
  }

  PyObject* arg = args[i];
  auto tuple = PyTuple_Check(arg);
  // NOLINTNEXTLINE(bugprone-branch-clone)
  const auto size2 = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
  std::vector<c10::SymInt> res;
  res.reserve(size2);
  for (const auto idx : c10::irange(size2)) {
    PyObject* obj =
        tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);

    // Elements of torch.Size are tensors during tracing, and we need to
    // record extra information before they are turned into an IntArrayRef
    if (traceable && jit::tracer::isTracing() && THPVariable_Check(obj)) {
      auto& var = THPVariable_Unpack(obj);
      jit::tracer::ArgumentStash::stashIntArrayRefElem(
          signature.params[i].name, size2, idx, var);
      try {
        res.emplace_back(var.item<int64_t>());
        continue;
      } catch (std::exception& e) {
        // Always throws; control never reaches the `continue` below.
        throw_intlist_exception(this, i, obj, idx, e);
      }
      continue;
    } else {
      // convert tensor to scalar outside of try / catch,
      // so that Tensor subclass exceptions will not be caught.
      if (THPUtils_checkLongExact(obj)) {
        // Fast path for plain numbers
        try {
          res.emplace_back(THPUtils_unpackLong(obj));
        } catch (std::exception& e) {
          throw_intlist_exception(this, i, obj, idx, e);
        }
      } else if (THPVariable_Check(obj)) {
        auto& var = THPVariable_Unpack(obj);
        // Only 1-element integral (or bool) tensors are accepted here.
        if (var.numel() != 1 ||
            !at::isIntegralType(
                var.dtype().toScalarType(), /*include_bool*/ true)) {
          throw_intlist_exception(this, i, obj, idx);
        }
        auto scalar = var.item();
        TORCH_CHECK(scalar.isIntegral(/*include bool*/ false));
        res.push_back(scalar.toSymInt());
      } else {
        try {
          if (is_symint(py::handle(obj))) {
            res.push_back(py::handle(obj).cast<c10::SymInt>());
          } else {
            // Accepts anything implementing __index__.
            res.emplace_back(THPUtils_unpackIndex(obj));
          }
        } catch (std::exception& e) {
          throw_intlist_exception(this, i, obj, idx, e);
        }
      }
    }
  }

  return res;
}
613
// Unpacks argument i as a vector of int64_t, substituting `default_intlist`
// when absent. Accepts a bare int/SymInt for fixed-size list parameters
// (broadcast), or a tuple/list of ints, SymInts (guarded to concrete
// ints), or 1-element integral Tensors.
inline std::vector<int64_t> PythonArgs::intlistWithDefault(
    int i,
    std::vector<int64_t> default_intlist) {
  if (!args[i])
    return default_intlist;
  PyObject* arg = args[i];
  // A bare int is allowed for fixed-size list parameters: replicate it.
  const auto size1 = signature.params[i].size;
  if (size1 > 0 && THPUtils_checkLong(arg)) {
    return std::vector<int64_t>(size1, THPUtils_unpackLong(arg));
  }
  if (size1 > 0 && torch::is_symint(py::handle(arg))) {
    // guard_int forces the symbolic int to a concrete value (may add guards).
    return std::vector<int64_t>(
        size1,
        py::handle(arg).cast<c10::SymInt>().guard_int(__FILE__, __LINE__));
  }
  auto tuple = PyTuple_Check(arg);
  // NOLINTNEXTLINE(bugprone-branch-clone)
  const auto size2 = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
  std::vector<int64_t> res(size2);
  for (const auto idx : c10::irange(size2)) {
    PyObject* obj =
        tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
    // Elements of torch.Size are tensors during tracing, and we need to
    // record extra information before they are turned into an IntArrayRef
    if (traceable && jit::tracer::isTracing() && THPVariable_Check(obj)) {
      auto& var = THPVariable_Unpack(obj);
      jit::tracer::ArgumentStash::stashIntArrayRefElem(
          signature.params[i].name, size2, idx, var);
      try {
        res[idx] = var.item<int64_t>();
        continue;
      } catch (std::exception& e) {
        // Always throws.
        throw_intlist_exception(this, i, obj, idx, e);
      }
    } else {
      // convert tensor to scalar outside of try / catch,
      // so that Tensor subclass exceptions will not be caught.
      if (THPUtils_checkLongExact(obj)) {
        // Fast path for plain numbers
        try {
          res[idx] = THPUtils_unpackLong(obj);
        } catch (std::exception& e) {
          throw_intlist_exception(this, i, obj, idx, e);
        }
      } else if (torch::is_symint(py::handle(obj))) {
        res[idx] = py::cast<c10::SymInt>(py::handle(obj))
                       .guard_int(__FILE__, __LINE__);
      } else if (THPVariable_Check(obj)) {
        auto& var = THPVariable_Unpack(obj);
        // Only 1-element integral (or bool) tensors are accepted here.
        if (var.numel() != 1 ||
            !at::isIntegralType(
                var.dtype().toScalarType(), /*include_bool*/ true)) {
          throw_intlist_exception(this, i, obj, idx);
        }
        res[idx] = var.item<int64_t>();
      } else {
        try {
          // Accepts anything implementing __index__.
          res[idx] = THPUtils_unpackIndex(obj);
        } catch (std::exception& e) {
          throw_intlist_exception(this, i, obj, idx, e);
        }
      }
    }
  }
  return res;
}
680
intlistOptional(int i)681 inline c10::OptionalArray<int64_t> PythonArgs::intlistOptional(int i) {
682 if (!args[i]) {
683 return {};
684 }
685 return intlist(i);
686 }
687
symintlistOptional(int i)688 inline c10::OptionalArray<c10::SymInt> PythonArgs::symintlistOptional(int i) {
689 if (!args[i]) {
690 return {};
691 }
692 return symintlist(i);
693 }
694
// Unpacks argument i (a Python tuple or list) as a vector of doubles.
// Precondition: args[i] is non-null (callers check; see doublelist below).
// Throws TypeError naming the offending element on failure.
inline std::vector<double> PythonArgs::getDoublelist(int i) {
  PyObject* arg = args[i];
  auto tuple = PyTuple_Check(arg);
  // NOLINTNEXTLINE(bugprone-branch-clone)
  auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
  std::vector<double> res(size);
  for (const auto idx : c10::irange(size)) {
    PyObject* obj =
        tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
    try {
      res[idx] = THPUtils_unpackDouble(obj);
    } catch (const std::exception&) {
      // Report a 1-based element position to the user.
      throw TypeError(
          "%s(): argument '%s' must be %s, but found element of type %s at pos %zu",
          signature.name.c_str(),
          signature.params[i].name.c_str(),
          signature.params[i].type_name().c_str(),
          Py_TYPE(obj)->tp_name,
          idx + 1);
    }
  }
  return res;
}
718
doublelistOptional(int i)719 inline c10::OptionalArray<double> PythonArgs::doublelistOptional(int i) {
720 if (!args[i]) {
721 return {};
722 }
723 return this->getDoublelist(i);
724 }
725
doublelist(int i)726 inline std::vector<double> PythonArgs::doublelist(int i) {
727 if (!args[i]) {
728 return {};
729 }
730 return this->getDoublelist(i);
731 }
732
toDispatchKeySetOptional(int i)733 inline std::optional<c10::DispatchKeySet> PythonArgs::toDispatchKeySetOptional(
734 int i) {
735 if (!args[i]) {
736 return {};
737 }
738 return py::cast<c10::DispatchKeySet>(py::handle(args[i]));
739 }
740
scalartypeWithDefault(int i,at::ScalarType default_scalartype)741 inline at::ScalarType PythonArgs::scalartypeWithDefault(
742 int i,
743 at::ScalarType default_scalartype) {
744 if (!args[i])
745 return default_scalartype;
746 return scalartype(i);
747 }
748
// Maps a Python type object (float/bool/int/complex) to the corresponding
// default ScalarType; anything else is assumed to be a torch.dtype.
// NOTE(review): the final reinterpret_cast performs no type check — the
// caller (argument parser) is presumably responsible for validating obj.
inline at::ScalarType toScalarType(PyObject* obj) {
  if (obj == (PyObject*)&PyFloat_Type) {
    return at::ScalarType::Double;
  }
  if (obj == (PyObject*)&PyBool_Type) {
    return at::ScalarType::Bool;
  }
  if (obj == (PyObject*)&PyLong_Type) {
    return at::ScalarType::Long;
  }
  if (obj == (PyObject*)&PyComplex_Type) {
    return at::ScalarType::ComplexDouble;
  }
  return reinterpret_cast<THPDtype*>(obj)->scalar_type;
}
764
scalartype(int i)765 inline at::ScalarType PythonArgs::scalartype(int i) {
766 if (!args[i]) {
767 auto scalartype = signature.params[i].default_scalartype;
768 return (scalartype == at::ScalarType::Undefined)
769 ? torch::tensors::get_default_scalar_type()
770 : scalartype;
771 }
772 PyObject* obj = args[i];
773 return toScalarType(obj);
774 }
775
scalartypeOptional(int i)776 inline std::optional<at::ScalarType> PythonArgs::scalartypeOptional(int i) {
777 if (!args[i])
778 return std::nullopt;
779 return scalartype(i);
780 }
781
// Extracts the at::Layout from a THPLayout object.
// NOTE(review): no type check — caller is presumably responsible for
// ensuring obj is a torch.layout.
inline at::Layout toLayout(PyObject* obj) {
  const auto layout = reinterpret_cast<THPLayout*>(obj);
  return layout->layout;
}
786
layout(int i)787 inline at::Layout PythonArgs::layout(int i) {
788 if (!args[i])
789 return signature.params[i].default_layout;
790 return toLayout(args[i]);
791 }
792
layoutWithDefault(int i,at::Layout default_layout)793 inline at::Layout PythonArgs::layoutWithDefault(
794 int i,
795 at::Layout default_layout) {
796 if (!args[i])
797 return default_layout;
798 return layout(i);
799 }
800
layoutOptional(int i)801 inline std::optional<at::Layout> PythonArgs::layoutOptional(int i) {
802 if (!args[i])
803 return std::nullopt;
804 return layout(i);
805 }
806
// Builds a Device on the current accelerator from a bare device index.
// Requires a detected accelerator (getAccelerator(true) asserts) and a
// non-negative index.
inline at::Device deviceFromLong(int64_t device_index) {
  TORCH_CHECK(device_index >= 0, "Device index must not be negative");
  return at::Device(
      at::getAccelerator(true).value(),
      static_cast<c10::DeviceIndex>(device_index));
}
813
// Converts a Python object to an at::Device. Accepts a torch.device, an
// int/SymInt index (interpreted on the current accelerator), or a device
// string such as "cuda:0" (THPUtils_unpackString raises for non-strings).
inline at::Device toDevice(PyObject* obj) {
  if (THPDevice_Check(obj)) {
    const auto device = reinterpret_cast<THPDevice*>(obj);
    return device->device;
  }
  if (THPUtils_checkLong(obj)) {
    return deviceFromLong(THPUtils_unpackLong(obj));
  }
  if (torch::is_symint(py::handle(obj))) {
    // Force the symbolic index to a concrete int (may add guards).
    auto device_index =
        py::cast<c10::SymInt>(py::handle(obj)).guard_int(__FILE__, __LINE__);
    return deviceFromLong(device_index);
  }
  const std::string& device_str = THPUtils_unpackString(obj);
  return at::Device(device_str);
}
830
device(int i)831 inline at::Device PythonArgs::device(int i) {
832 if (!args[i]) {
833 return torch::tensors::get_default_device();
834 }
835 return toDevice(args[i]);
836 }
837
// Unpacks argument i as a Device, substituting `default_device` when the
// argument was not supplied.
inline at::Device PythonArgs::deviceWithDefault(
    int i,
    const at::Device& default_device) {
  return args[i] != nullptr ? device(i) : default_device;
}
845
deviceOptional(int i)846 inline std::optional<at::Device> PythonArgs::deviceOptional(int i) {
847 if (!args[i])
848 return std::nullopt;
849 return device(i);
850 }
851
// Unpacks argument i as a Dimname. The argument must have been supplied
// (there is no default for dimnames).
inline at::Dimname PythonArgs::dimname(int i) {
  TORCH_INTERNAL_ASSERT(args[i] != nullptr);
  return THPDimname_parse(args[i]);
}
856
// Parses a Python tuple or list of dimension names into Dimnames.
inline std::vector<at::Dimname> parseDimnameList(PyObject* arg) {
  auto tuple = PyTuple_Check(arg);
  // NOLINTNEXTLINE(bugprone-branch-clone)
  auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
  std::vector<at::Dimname> res;
  res.reserve(size); // single allocation up front
  for (const auto idx : c10::irange(size)) {
    PyObject* obj =
        tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
    res.push_back(THPDimname_parse(obj));
  }
  return res;
}
870
871 inline std::optional<std::vector<at::Dimname>> PythonArgs::
toDimnameListOptional(int i)872 toDimnameListOptional(int i) {
873 if (!args[i])
874 return std::nullopt;
875 return parseDimnameList(args[i]);
876 }
877
// Unpacks argument i as a list of Dimnames. For parameters declared with
// size 1, a single bare dimname is accepted and wrapped in a list.
inline std::vector<at::Dimname> PythonArgs::dimnamelist(int i) {
  TORCH_INTERNAL_ASSERT(args[i]);
  PyObject* arg = args[i];
  auto size = signature.params[i].size;
  TORCH_INTERNAL_ASSERT(size == 0 || size == 1);
  if (size == 1 && THPUtils_checkDimname(arg)) {
    return {THPDimname_parse(arg)};
  }
  return parseDimnameList(arg);
}
888
// Unpacks argument i as a MemoryFormat; defaults to Contiguous when the
// argument was not supplied. Requires a torch.memory_format instance.
inline at::MemoryFormat PythonArgs::memoryformat(int i) {
  if (!args[i])
    return at::MemoryFormat::Contiguous;
  TORCH_CHECK(
      THPMemoryFormat_Check(args[i]),
      "memory_format arg must be an instance of the torch.memory_format");
  const auto memory_format = reinterpret_cast<THPMemoryFormat*>(args[i]);
  return memory_format->memory_format;
}
898
memoryformatOptional(int i)899 inline std::optional<at::MemoryFormat> PythonArgs::memoryformatOptional(int i) {
900 if (!args[i])
901 return std::nullopt;
902 return memoryformat(i);
903 }
904
// Unpacks argument i as a QScheme; defaults to kPerTensorAffine when the
// argument was not supplied. Requires a torch.qscheme instance.
inline at::QScheme PythonArgs::toQScheme(int i) {
  if (!args[i])
    return at::kPerTensorAffine;
  TORCH_CHECK(
      THPQScheme_Check(args[i]),
      "qscheme arg must be an instance of the torch.qscheme");
  const auto qscheme = reinterpret_cast<THPQScheme*>(args[i]);
  return qscheme->qscheme;
}
914
// Unpacks argument i as a std::string, using the signature's default.
inline std::string PythonArgs::string(int i) {
  return stringWithDefault(i, signature.params[i].default_string);
}
918
stringWithDefault(int i,const std::string & default_str)919 inline std::string PythonArgs::stringWithDefault(
920 int i,
921 const std::string& default_str) {
922 if (!args[i])
923 return default_str;
924 return THPUtils_unpackString(args[i]);
925 }
926
stringOptional(int i)927 inline std::optional<std::string> PythonArgs::stringOptional(int i) {
928 if (!args[i])
929 return std::nullopt;
930 return THPUtils_unpackString(args[i]);
931 }
932
// Unpacks argument i as a string_view, using the signature's default.
// NOTE: the view borrows storage owned by the Python string object.
inline c10::string_view PythonArgs::stringView(int i) {
  return stringViewWithDefault(i, signature.params[i].default_string);
}
936
// Unpacks argument i as a string_view, substituting `default_str` when
// the argument was not supplied.
inline c10::string_view PythonArgs::stringViewWithDefault(
    int i,
    const c10::string_view default_str) {
  return args[i] != nullptr ? THPUtils_unpackStringView(args[i]) : default_str;
}
944
stringViewOptional(int i)945 inline std::optional<c10::string_view> PythonArgs::stringViewOptional(int i) {
946 if (!args[i])
947 return std::nullopt;
948 return THPUtils_unpackStringView(args[i]);
949 }
950
toInt64(int i)951 inline int64_t PythonArgs::toInt64(int i) {
952 if (!args[i])
953 return signature.params[i].default_int;
954 if (traceable && jit::tracer::isTracing() && THPVariable_Check(args[i])) {
955 auto& var = THPVariable_Unpack(args[i]);
956 jit::tracer::ArgumentStash::stashValue(
957 signature.params[i].name, idx, var, c10::IntType::get());
958 }
959 if (torch::is_symint(py::handle(args[i]))) {
960 return py::cast<c10::SymInt>(py::handle(args[i]))
961 .guard_int(__FILE__, __LINE__);
962 }
963 return THPUtils_unpackLong(args[i]);
964 }
965
toSymInt(int i)966 inline c10::SymInt PythonArgs::toSymInt(int i) {
967 if (!args[i]) {
968 return c10::SymInt(signature.params[i].default_int);
969 }
970
971 if (traceable && jit::tracer::isTracing() && THPVariable_Check(args[i])) {
972 auto& var = THPVariable_Unpack(args[i]);
973 jit::tracer::ArgumentStash::stashValue(
974 signature.params[i].name, idx, var, c10::IntType::get());
975 }
976
977 return py::cast<c10::SymInt>(py::handle(args[i]));
978 }
979
toSymBool(int i)980 inline c10::SymBool PythonArgs::toSymBool(int i) {
981 if (!args[i]) {
982 return c10::SymBool(signature.params[i].default_bool);
983 }
984 if (traceable && jit::tracer::isTracing() && THPVariable_Check(args[i])) {
985 auto& var = THPVariable_Unpack(args[i]);
986 jit::tracer::ArgumentStash::stashValue(
987 signature.params[i].name, idx, var, c10::BoolType::get());
988 }
989
990 return py::cast<c10::SymBool>(py::handle(args[i]));
991 }
992
toInt64WithDefault(int i,int64_t default_int)993 inline int64_t PythonArgs::toInt64WithDefault(int i, int64_t default_int) {
994 if (!args[i])
995 return default_int;
996 return toInt64(i);
997 }
998
toInt64Optional(int i)999 inline std::optional<int64_t> PythonArgs::toInt64Optional(int i) {
1000 if (!args[i])
1001 return std::nullopt;
1002 return toInt64(i);
1003 }
1004
toSymIntOptional(int i)1005 inline std::optional<c10::SymInt> PythonArgs::toSymIntOptional(int i) {
1006 if (!args[i])
1007 return std::nullopt;
1008 return toSymInt(i);
1009 }
1010
toBoolOptional(int i)1011 inline std::optional<bool> PythonArgs::toBoolOptional(int i) {
1012 if (!args[i]) {
1013 return std::nullopt;
1014 }
1015 return toBool(i);
1016 }
1017
toDoubleOptional(int i)1018 inline std::optional<double> PythonArgs::toDoubleOptional(int i) {
1019 if (!args[i]) {
1020 return std::nullopt;
1021 }
1022 return toDouble(i);
1023 }
1024
toDouble(int i)1025 inline double PythonArgs::toDouble(int i) {
1026 if (!args[i])
1027 return signature.params[i].default_double;
1028 if (torch::is_symfloat(py::handle(args[i]))) {
1029 return py::cast<c10::SymFloat>(py::handle(args[i]))
1030 .guard_float(__FILE__, __LINE__);
1031 }
1032 if (torch::is_symint(py::handle(args[i]))) {
1033 return static_cast<double>(py::cast<c10::SymInt>(py::handle(args[i]))
1034 .guard_int(__FILE__, __LINE__));
1035 }
1036 return THPUtils_unpackDouble(args[i]);
1037 }
1038
toBool(int i)1039 inline bool PythonArgs::toBool(int i) {
1040 if (!args[i])
1041 return signature.params[i].default_bool;
1042 if (torch::is_symbool(py::handle(args[i]))) {
1043 return py::cast<c10::SymBool>(py::handle(args[i]))
1044 .guard_bool(__FILE__, __LINE__);
1045 }
1046 return args[i] == Py_True;
1047 }
1048
toDoubleWithDefault(int i,double default_double)1049 inline double PythonArgs::toDoubleWithDefault(int i, double default_double) {
1050 if (!args[i])
1051 return default_double;
1052 return toDouble(i);
1053 }
1054
toComplex(int i)1055 inline c10::complex<double> PythonArgs::toComplex(int i) {
1056 if (!args[i])
1057 return *(reinterpret_cast<const c10::complex<double>*>(
1058 signature.params[i].default_complex));
1059 return THPUtils_unpackComplexDouble(args[i]);
1060 }
1061
toComplexWithDefault(int i,c10::complex<double> default_value)1062 inline c10::complex<double> PythonArgs::toComplexWithDefault(
1063 int i,
1064 c10::complex<double> default_value) {
1065 if (!args[i])
1066 return default_value;
1067 return toComplex(i);
1068 }
1069
toBoolWithDefault(int i,bool default_bool)1070 inline bool PythonArgs::toBoolWithDefault(int i, bool default_bool) {
1071 if (!args[i])
1072 return default_bool;
1073 return toBool(i);
1074 }
1075
isNone(int i)1076 inline bool PythonArgs::isNone(int i) {
1077 return args[i] == nullptr;
1078 }
1079
generator(int i)1080 inline std::optional<at::Generator> PythonArgs::generator(int i) {
1081 if (!args[i])
1082 return std::nullopt;
1083 return reinterpret_cast<THPGenerator*>(args[i])->cdata;
1084 }
1085
storage(int i)1086 inline at::Storage PythonArgs::storage(int i) {
1087 if (!args[i])
1088 return at::Storage();
1089 return createStorage(args[i]);
1090 }
1091
storage(int i,at::ScalarType & storage_scalar_type,bool & is_typed_storage)1092 inline at::Storage PythonArgs::storage(
1093 int i,
1094 at::ScalarType& storage_scalar_type,
1095 bool& is_typed_storage) {
1096 at::Storage storage;
1097 if (!args[i]) {
1098 storage = at::Storage();
1099 is_typed_storage = false;
1100 storage_scalar_type = at::ScalarType::Undefined;
1101 } else {
1102 std::tie(storage, storage_scalar_type, is_typed_storage) =
1103 createStorageGetType(args[i]);
1104 }
1105 return storage;
1106 }
1107
stream(int i)1108 inline c10::Stream PythonArgs::stream(int i) {
1109 if (!args[i])
1110 return c10::Stream(
1111 c10::Stream::Default::DEFAULT, c10::Device(c10::DeviceType::CPU, -1));
1112 if (!THPStream_Check(args[i])) {
1113 throw TypeError(
1114 "expected Stream object. Got '%s'", Py_TYPE(args[i])->tp_name);
1115 }
1116 return c10::Stream::unpack3(
1117 ((THPStream*)args[i])->stream_id,
1118 static_cast<c10::DeviceIndex>(((THPStream*)args[i])->device_index),
1119 static_cast<c10::DeviceType>(((THPStream*)args[i])->device_type));
1120 }
1121
pyobject(int i)1122 inline PyObject* PythonArgs::pyobject(int i) {
1123 if (!args[i])
1124 return Py_None;
1125 return args[i];
1126 }
1127
/*
 *
 * Handle __torch_function__ overrides if we know that there are overloaded
 * arguments. All objects stored in r.overloaded_args must have a
 * __torch_function__ implementation and the arguments must be ordered in order
 * of precedence. Precedence goes from left to right in the order of the
 * signature of the function the overloaded arguments were passed to, except
 * subclasses are always considered before superclasses.
 *
 * If the result of calling __torch_function__ is NotImplemented, the
 * next implementation in the precedence order is called. If all
 * arguments return NotImplemented from their __torch_function__
 * implementation, a TypeError is raised in Python.
 *
 * Assumes overloaded_args has at least one entry. All entries must have
 * a __torch_function__ attribute that resolves to a callable that
 * accepts a torch API function, a tuple of arguments, and a dict of
 * keyword arguments for the torch API function.
 *
 * It is sufficient to call PythonArgs::has_torch_function before
 * calling this function to verify that there are valid arguments
 * present. If that is not done then special care must be taken to
 * ensure there are arguments that are overloaded with
 * __torch_function__.
 *
 * See torch._overrides.handle_torch_function for the equivalent
 * code in the pure-python implementation.
 *
 * 'r' is a parsed PythonArgs instance, returned from
 * PythonArgParser::parse.
 *
 * 'args' is a reference to the python tuple of arguments to the torch
 * API function.
 *
 * 'kwargs' is a reference to the python dict of keyword arguments to
 * the torch API function.
 *
 * 'torch_api' is a reference to a python torch API namespace.
 *
 * 'torch_api_function' is the reference to the original torch method, usually,
 * we can use torch_api and func_name to get torch_api_function. In some cases,
 * e.g., torch custom op, we create the function in C++, if we still use
 * torch_api and func_name to fetch original api, a cyclic call will happen.
 *
 * 'overloaded_args' is the args which have overloaded __torch_function__.
 *
 * 'func_name' is the name of the original torch method.
 *
 * TODO: we could use different names for the following 'handle_torch_function'
 * instead of overloading.
 *
 */
// Used for Tensor methods with arguments. 'self' is the object the method
// was invoked on.
auto handle_torch_function(
    PythonArgs& r,
    PyObject* self,
    PyObject* args,
    PyObject* kwargs,
    PyObject* torch_api,
    const char* module_name,
    const char* func_name_override = nullptr) -> PyObject*;
1189
// Used for functions which need to parse python args.
auto handle_torch_function(
    PythonArgs& r,
    PyObject* args,
    PyObject* kwargs,
    PyObject* torch_api,
    const char* module_name,
    const char* func_name_override = nullptr) -> PyObject*;
1198
// Used for functions that have no argument parsing. Defaults target
// Tensor methods (torch_api = THPVariableClass, module "torch.Tensor").
auto handle_torch_function(
    PyObject* self,
    const std::string& func_name,
    PyObject* args = nullptr,
    PyObject* kwargs = nullptr,
    PyObject* torch_api = THPVariableClass,
    const std::string& module_name = "torch.Tensor") -> PyObject*;
1207
// Used for functions created in C++, e.g., C++ custom op, which doesn't use
// PythonArgParser to get overloaded_args.
// Selects whether the __torch_function__ or __torch_dispatch__ protocol is
// invoked.
enum class TorchFunctionName { TorchFunction, TorchDispatch };

auto TORCH_PYTHON_API handle_torch_function_no_python_arg_parser(
    at::ArrayRef<PyObject*> overloaded_args,
    PyObject* args,
    PyObject* kwargs,
    const char* func_name,
    PyObject* torch_api_function,
    const char* module_name,
    TorchFunctionName torch_function_name = TorchFunctionName::TorchFunction)
    -> PyObject*;
1221
// Used for getters of Tensor properties.
auto handle_torch_function_getter(
    THPVariable* self,
    const std::string& property_name) -> PyObject*;
1226
// Used for setters of Tensor properties.
auto handle_torch_function_setter(
    THPVariable* self,
    const std::string& property_name,
    PyObject* value) -> int;
1232
// Used for __getitem__ and __setitem__. 'val' is nullptr for __getitem__
// and the assigned value for __setitem__.
auto handle_torch_function_indexing(
    PyObject* self,
    PyObject* index,
    PyObject* val = nullptr) -> PyObject*;
1238
/*
 * Check if the input obj is Tensor type, including its subclass, or overloaded
 * type. If the type defines __torch_function__, it also returns true.
 * Otherwise returns false. If the class is not torch.Tensor, and it defines
 * __torch_function__, we append obj to overloaded_args.
 *
 * 'obj': the input argument to be checked
 * 'overloaded_args': the vector to append the overloaded args.
 */
bool is_tensor_and_append_overloaded(
    PyObject* obj,
    std::vector<PyObject*>* overloaded_args);
1251
/*
 * Check if the input obj is Tensor List or Tensor Tuple type. First check
 * whether obj is Tuple or List type, if true, iterate over each element and
 * check whether it is Tensor type, including its subclass or overloaded type.
 * At the same time, the overloaded arg is appended to the overloaded_args.
 *
 * 'obj': the input argument to be checked
 * 'overloaded_args': the vector to append the overloaded args.
 * 'argnum': the number of total arguments of the function being checked.
 * 'throw_error': whether to throw an error if any element in the list or
 *                tuple is not tensor type or overloaded.
 */
bool is_tensor_list_and_append_overloaded(
    PyObject* obj,
    std::vector<PyObject*>* overloaded_args,
    size_t argnum,
    bool throw_error);
1269
/* Given an argument that is definitely a tensor and is definitely overloaded,
 * append it to the overloaded arguments list. Use this instead of
 * is_tensor_and_append_overloaded in situations where you have a PyObject
 * and you know it definitely is a Tensor and it is definitely overloaded.
 *
 * 'overloaded_args': the vector to append the overloaded args
 * 'obj': the input tensor that is overloaded
 */
void append_overloaded_tensor(
    std::vector<PyObject*>* overloaded_args,
    PyObject* obj);
1281
/* Given an argument that is definitely a type and is definitely overloaded,
 * append it to the overloaded arguments list. Use this only with
 * __torch_dispatch__, where we operate on classes that have a
 * __torch_dispatch__ classmethod.
 *
 * 'overloaded_args': the vector to append the overloaded type
 * 'obj': the input class that has a __torch_dispatch__ classmethod.
 */
void append_overloaded_type(
    std::vector<PyObject*>* overloaded_args,
    PyObject* obj);
1293
1294 } // namespace torch
1295