1 #pragma once
2
3 #include <ATen/core/ivalue.h>
4 #include <ATen/core/jit_type.h>
5 #include <ATen/core/qualified_name.h>
6 #include <ATen/core/stack.h>
7 #include <pybind11/complex.h>
8 #include <pybind11/pybind11.h>
9 #include <pybind11/pytypes.h>
10 #include <torch/csrc/Device.h>
11 #include <torch/csrc/Dtype.h>
12 #include <torch/csrc/Export.h>
13 #include <torch/csrc/Layout.h>
14 #include <torch/csrc/QScheme.h>
15 #include <torch/csrc/Stream.h>
16 #include <torch/csrc/jit/api/module.h>
17 #include <torch/csrc/jit/frontend/schema_matching.h>
18 #include <torch/csrc/jit/frontend/tracer.h>
19 #include <torch/csrc/jit/python/module_python.h>
20 #include <torch/csrc/jit/python/python_custom_class.h>
21 #include <torch/csrc/jit/python/python_tracer.h>
22 #include <torch/csrc/jit/resource_guard.h>
23 #include <torch/csrc/jit/runtime/operator.h>
24 #include <torch/csrc/utils/pybind.h>
25 #include <torch/csrc/utils/python_arg_parser.h>
26 #include <torch/csrc/utils/six.h>
27 #ifdef USE_DISTRIBUTED
28 #include <torch/csrc/distributed/rpc/py_rref.h>
29 #include <torch/csrc/distributed/rpc/rref_impl.h>
30 #endif
31
32 #include <ATen/core/function_schema.h>
33 #include <c10/core/Stream.h>
34 #include <c10/util/Exception.h>
35 #include <c10/util/irange.h>
36 #include <optional>
37
38 #include <algorithm>
39 #include <cstddef>
40 #include <string>
41 #include <utility>
42 #include <vector>
43
// The visibility attribute is to avoid a warning about storing a field in the
// struct that has a different visibility (from pybind) than the struct.
// On _WIN32 the macro expands to nothing; elsewhere it expands to the
// GCC/Clang attribute `visibility("hidden")`.
#ifdef _WIN32
#define VISIBILITY_HIDDEN
#else
#define VISIBILITY_HIDDEN __attribute__((visibility("hidden")))
#endif
51
52 namespace torch::jit {
53
// Callback mapping a name (std::string) to a Python object.
// NOTE(review): presumably used to resolve Python-scope names for the
// frontend; confirm against callers.
using ResolutionCallback = std::function<py::object(std::string)>;

// Declared here, defined out of line.
// NOTE(review): the name suggests it drops registry entries associated with
// `ptr`; the semantics are not visible in this header — confirm in the .cpp.
void clear_registered_instances(void* ptr);

// Convert a Python object to an IValue of the requested `type`.
// `N` is an optional size forwarded from a schema argument (see
// argumentToIValue, which passes `argument.N()`). Defined out of line.
TORCH_PYTHON_API IValue toIValue(
    py::handle obj,
    const TypePtr& type,
    std::optional<int32_t> N = std::nullopt);

// Convert an IValue into a Python object. Defined out of line.
TORCH_PYTHON_API py::object toPyObject(IValue ivalue);
64
// Hack to overload the behavior of toIValue to accept Python
// numbers in places where a Tensor is expected
// See also torch::should_allow_numbers_as_tensors
//
// Scope guard: construct it to change the setting for the current scope.
// NOTE(review): the constructor and destructor are defined out of line;
// `old_` presumably holds the previous setting, which the destructor
// restores — confirm in the .cpp.
// Since the destructor restores state, copying a guard would restore it
// twice, so copy construction/assignment are deleted.
class ToIValueAllowNumbersAsTensors {
  bool old_;

 public:
  ToIValueAllowNumbersAsTensors(bool enable);
  ~ToIValueAllowNumbersAsTensors();

  ToIValueAllowNumbersAsTensors(const ToIValueAllowNumbersAsTensors&) = delete;
  ToIValueAllowNumbersAsTensors& operator=(
      const ToIValueAllowNumbersAsTensors&) = delete;
};
75
76 // Wrap Python function to guard deref
77 // NB: Need VISIBILITY_HIDDEN for silencing compiler error,
78 // 'torch::jit::PythonFunctionGuard' declared with greater visibility than the
79 // type of its field 'torch::jit::PythonFunctionGuard::func_'
80 struct VISIBILITY_HIDDEN PythonFunctionGuard {
PythonFunctionGuardPythonFunctionGuard81 explicit PythonFunctionGuard(py::function func) : func_(std::move(func)) {}
82
~PythonFunctionGuardPythonFunctionGuard83 ~PythonFunctionGuard() {
84 pybind11::gil_scoped_acquire ag;
85 func_.dec_ref();
86 // explicitly setting PyObject* to nullptr to prevent py::object's dtor to
87 // decref on the PyObject again.
88 // See Note [Destructing py::object] in python_ivalue.h
89 func_.ptr() = nullptr;
90 }
91
92 py::function func_;
93 };
94
95 // The PythonFutureWrapper for ivalue::Future
96 //
97 // NB: VISIBILITY_HIDDEN is for silencing compiling error,
98 // "error: 'torch::jit::PythonFutureWrapper' declared with greater visibility
99 // than the type of its field 'torch::jit::PythonFutureWrapper::unwrap_func'
100 // [-Werror=attributes]"
101 //
102 // NB: inherit from enable_shared_from_this because then(py::function) needs to
103 // get a shared_ptr from this pointer.
104 struct VISIBILITY_HIDDEN PythonFutureWrapper
105 : std::enable_shared_from_this<PythonFutureWrapper> {
106 using UnwrapFunc = std::function<void(py::object)>;
107
108 explicit PythonFutureWrapper(
109 c10::intrusive_ptr<c10::ivalue::Future> fut,
110 std::optional<UnwrapFunc> unwrap_func = std::nullopt)
futPythonFutureWrapper111 : fut(std::move(fut)), unwrap_func(std::move(unwrap_func)) {}
112
113 explicit PythonFutureWrapper(const PythonFutureWrapper&) = delete;
114 PythonFutureWrapper& operator=(const PythonFutureWrapper&) = delete;
115
donePythonFutureWrapper116 bool done() {
117 return fut->completed();
118 }
119
valuePythonFutureWrapper120 py::object value() {
121 // acquiring GIL as toPyObject creates new py::object
122 // without grabbing the GIL.
123 py::gil_scoped_acquire acquire;
124 py::object py_obj = toPyObject(fut->value());
125 // unwrap_func is a general compositional function that takes in a
126 // py::object and executes some python function. It is currently mostly used
127 // to throw python exceptions.
128 if (unwrap_func) {
129 (*unwrap_func)(py_obj);
130 }
131 return py_obj;
132 }
133
waitPythonFutureWrapper134 py::object wait() {
135 fut->wait();
136 if (jit::tracer::isTracing()) {
137 auto graph = jit::tracer::getTracingState()->graph;
138
139 Value* fut_val = jit::tracer::getValueTrace(fut);
140 auto output = graph->insert(aten::wait, {fut_val});
141 jit::tracer::setValueTrace(fut->value(), output);
142 }
143 return value();
144 }
145
146 // The py::function cb arg must take a std::shared_ptr<PythonFutureWrapper>
147 // (i.e., torch._C.Future) as the only argument. If the type mismatches, an
148 // error will be thrown when waiting for the value of this returned Future.
thenPythonFutureWrapper149 std::shared_ptr<PythonFutureWrapper> then(py::function cb) {
150 // We need this an additional layer of wrapper here to guard the
151 // destruction of the py::function object. Because, the
152 // Future owns a reference to the py::function in its callback
153 // vector, but Future does not acquire GIL on destruction.
154 auto pf = std::make_shared<PythonFunctionGuard>(std::move(cb));
155
156 return std::make_shared<jit::PythonFutureWrapper>(fut->then(
157 // Capture a copy of the ivalue::Future instead of the `this` pointer
158 // because the PythonFutureWrapper object could have been deleted
159 // when the callbacks are fired. For example, RPC only captures the
160 // ivalue::Future instead of PythonFutureWrapper in JitFuture's
161 // callback functions. Hence, if user code does not hold a reference to
162 // this PythonFutureWrapper object, there is no guarantee that the
163 // PythonFutureWrapper is still valid when running the callback.
164 [pyFut(this->getPtr()),
165 pf(std::move(pf))](c10::ivalue::Future& /* unused */) -> IValue {
166 try {
167 pybind11::gil_scoped_acquire ag;
168 return toIValue(pf->func_(pyFut), PyObjectType::get());
169 } catch (py::error_already_set& e) {
170 auto err = std::runtime_error(c10::str(
171 "Got the following error when running the callback: ",
172 e.what()));
173 {
174 pybind11::gil_scoped_acquire ag;
175 // Release ownership on py::objects and also restore Python
176 // Error Indicator.
177 e.restore();
178 // Clear the Python Error Indicator as we has recorded the
179 // exception in the response message.
180 PyErr_Clear();
181 }
182
183 throw std::runtime_error(err);
184 }
185 },
186 PyObjectType::get()));
187 }
188
add_done_callbackPythonFutureWrapper189 void add_done_callback(py::function cb) {
190 auto pf = std::make_shared<PythonFunctionGuard>(std::move(cb));
191 // NOLINTNEXTLINE(modernize-avoid-bind)
192 fut->addCallback(std::bind(
193 [pyFut(this->getPtr())](
194 const std::shared_ptr<PythonFunctionGuard>& pf) {
195 try {
196 pybind11::gil_scoped_acquire ag;
197 pf->func_(pyFut);
198 } catch (py::error_already_set& e) {
199 {
200 pybind11::gil_scoped_acquire ag;
201 // Release ownership on py::objects and also restore Python
202 // Error Indicator.
203 e.restore();
204 // Clear the Python Error Indicator as we has recorded the
205 // exception in the response message.
206 PyErr_Clear();
207 }
208 // Log and ignore exceptions raised through the callback
209 LOG(ERROR) << "Got the following error when running the callback: "
210 << e.what();
211
212 } catch (const std::exception& e) {
213 // Log and ignore exceptions raised through the callback
214 LOG(ERROR) << "Got the following error when running the callback: "
215 << e.what();
216 }
217 },
218 std::move(pf)));
219 }
220
markCompletedPythonFutureWrapper221 void markCompleted(const py::object& pyValue) {
222 DCHECK(PyGILState_Check());
223 IValue value = toIValue(pyValue, PyObjectType::get());
224
225 py::gil_scoped_release release;
226 fut->markCompleted(std::move(value));
227 }
228
229 c10::intrusive_ptr<c10::ivalue::Future> fut;
230 // unwrap_func works like a callback for the value returned by
231 // PythonFutureWrapper::wait().
232 std::optional<UnwrapFunc> unwrap_func;
233
234 private:
getPtrPythonFutureWrapper235 std::shared_ptr<PythonFutureWrapper> getPtr() {
236 return shared_from_this();
237 }
238 };
239
240 // The PythonAwaitWrapper for ivalue::Await
241 //
242 // Expresses delayed function execution with Lazy semantic.
243 // i.e. Await[W] in eager mode can be used as W.
244 // When the attribute of W type is requested, Await[W] will return the
245 // attribute of W, transparently calling wait() beforehand.
246 // No Lazy semantic for script, explicit wait(Await[W]) -> W must be called to
247 // convert to type W.
248 //
249 // The Await object takes shared ownership of specified function and the
250 // arguments. After first call for wait() it owns the result. Deliberately no
251 // type inference for eager mode.
252 struct VISIBILITY_HIDDEN PythonAwaitWrapper
253 : std::enable_shared_from_this<PythonAwaitWrapper> {
PythonAwaitWrapperPythonAwaitWrapper254 explicit PythonAwaitWrapper(c10::intrusive_ptr<c10::ivalue::Await> aw)
255 : aw_(std::move(aw)) {}
PythonAwaitWrapperPythonAwaitWrapper256 explicit PythonAwaitWrapper(py::handle input) {
257 args_ = py::tuple(1u);
258 args_[0] = input;
259 auto type = PyObjectType::get();
260 aw_ = c10::make_intrusive<c10::ivalue::Await>(type);
261 aw_->markCompleted(toIValue(input, type));
262 }
263
PythonAwaitWrapperPythonAwaitWrapper264 explicit PythonAwaitWrapper(py::function pf, py::tuple args)
265 : args_(std::move(args)) {
266 pyfg_ = std::make_shared<torch::jit::PythonFunctionGuard>(std::move(pf));
267
268 std::function<IValue()> f = [fg(pyfg_), &args(args_)]() {
269 pybind11::gil_scoped_acquire ag;
270 return toIValue(fg->func_(*args), PyObjectType::get());
271 };
272 aw_ = c10::make_intrusive<c10::ivalue::Await>(
273 PyObjectType::get(), std::move(f));
274 }
275
276 explicit PythonAwaitWrapper(const PythonAwaitWrapper&) = delete;
277 PythonAwaitWrapper& operator=(const PythonAwaitWrapper&) = delete;
278
waitPythonAwaitWrapper279 py::object wait() {
280 py::gil_scoped_acquire acquire;
281 return toPyObject(aw_->wait());
282 }
283
284 // Nowait semantic means trivial case when Await is constructed from the
285 // result
is_nowaitPythonAwaitWrapper286 bool is_nowait() {
287 return pyfg_ == nullptr;
288 }
289
fnPythonAwaitWrapper290 const py::function fn() {
291 TORCH_CHECK(
292 pyfg_, "Await constructed as awaitable_nowait does not have fn");
293 return pyfg_->func_;
294 }
295
argsPythonAwaitWrapper296 const py::tuple args() {
297 return args_;
298 }
299
typePythonAwaitWrapper300 TypePtr type() {
301 return aw_->type();
302 }
303
304 c10::intrusive_ptr<c10::ivalue::Await> aw_;
305 std::shared_ptr<torch::jit::PythonFunctionGuard> pyfg_;
306 py::tuple args_;
307
308 private:
getPtrPythonAwaitWrapper309 std::shared_ptr<PythonAwaitWrapper> getPtr() {
310 return shared_from_this();
311 }
312 };
313
314 // error reporting: when reporting user-caused errors, these functions should
315 // not use AT_ERROR macros, since these macros add stack trace information
316 // that is confusing to display to the end user since it always reports
317 // locations in libtorch code rather than user code.
318
get_python_cu()319 inline std::shared_ptr<CompilationUnit> get_python_cu() {
320 return py::module::import("torch.jit._state")
321 .attr("_python_cu")
322 .cast<std::shared_ptr<CompilationUnit>>();
323 }
324
325 struct TypedIValue : public std::pair<IValue, TypePtr> {
326 using pair::pair;
327
ivalueTypedIValue328 IValue& ivalue() {
329 return this->first;
330 }
typeTypedIValue331 TypePtr& type() {
332 return this->second;
333 }
334 };
335
toDictKeyIValue(py::handle key)336 inline TypedIValue toDictKeyIValue(py::handle key) {
337 if (py::isinstance<py::str>(key)) {
338 return TypedIValue(
339 ConstantString::create(py::cast<std::string>(key)), StringType::get());
340 } else if (py::isinstance<py::int_>(key)) {
341 return TypedIValue(py::cast<int64_t>(key), IntType::get());
342 } else if (py::isinstance<py::float_>(key)) {
343 return TypedIValue(py::cast<double>(key), FloatType::get());
344 } else {
345 AT_ERROR("Dictionary inputs may only have string, int, or float keys");
346 }
347 }
348
unifyOrInitializeType(const TypePtr & accum,const TypePtr & unify)349 inline std::optional<TypePtr> unifyOrInitializeType(
350 const TypePtr& accum,
351 const TypePtr& unify) {
352 if (!accum) {
353 return unify;
354 }
355 return unifyTypes(accum, unify);
356 }
357
358 using InferredType = c10::InferredType;
359
360 InferredType tryToInferContainerType(py::handle input, bool primitiveTypeOnly);
361
362 // Try to infer the type of a Python object
363 // The type cannot be inferred if:
364 // input is an empty container (list, dict)
365 // input is an list with element types that cannot be unified
366 // input is an dict with key or value types that cannot be unified
tryToInferType(py::handle input)367 inline InferredType tryToInferType(py::handle input) {
368 // Try tensor types
369 if (THPVariable_Check(input.ptr())) {
370 return InferredType(TensorType::get());
371 }
372
373 if (input.is_none()) {
374 return InferredType(NoneType::get());
375 }
376
377 if (py::isinstance<StrongFunctionPtr>(input)) {
378 auto fn = py::cast<StrongFunctionPtr>(input).function_;
379 return InferredType(FunctionType::create(fn));
380 }
381
382 // Try basic types first
383 if (py::isinstance<py::bool_>(input)) {
384 return InferredType(BoolType::get());
385 // NOLINTNEXTLINE(bugprone-branch-clone)
386 } else if (py::isinstance<py::int_>(input)) {
387 return InferredType(IntType::get());
388 } else if (py::isinstance<py::float_>(input)) {
389 return InferredType(FloatType::get());
390 } else if (PyComplex_CheckExact(input.ptr())) {
391 return InferredType(ComplexType::get());
392 } else if (py::isinstance<py::str>(input)) {
393 return InferredType(StringType::get());
394 } else if (THPLayout_Check(input.ptr())) {
395 return InferredType(IntType::get());
396 } else if (THPDevice_Check(input.ptr())) {
397 return InferredType(DeviceObjType::get());
398 } else if (THPGenerator_Check(input.ptr())) {
399 return InferredType(GeneratorType::get());
400 } else if (THPStream_Check(input.ptr())) {
401 return InferredType(StreamObjType::get());
402 } else if (THPDtype_Check(input.ptr())) {
403 return InferredType(IntType::get());
404 } else if (THPQScheme_Check(input.ptr())) {
405 return InferredType(IntType::get());
406 } else if (THPLayout_Check(input.ptr())) {
407 return InferredType(IntType::get());
408 }
409
410 auto enum_type = py::module::import("enum").attr("Enum");
411 py::bool_ isEnumValue = py::isinstance(input, enum_type);
412 if (py::cast<bool>(isEnumValue)) {
413 auto enum_class = input.attr("__class__");
414 auto enum_type = py::cast<TypePtr>(
415 py::module::import("torch.jit.annotations")
416 .attr("try_ann_to_type")(enum_class, SourceRange()));
417 return InferredType(std::move(enum_type));
418 }
419
420 py::bool_ isClass =
421 py::module::import("inspect").attr("isclass")(input.get_type());
422 if (py::cast<bool>(isClass)) {
423 // Assume that the class is compiled already or will compile. Invalidate
424 // this later if needed.
425 bool class_compiled = true;
426
427 // Check if the type is already compiled.
428 py::object existing_ty = py::module::import("torch.jit._state")
429 .attr("_get_script_class")(input.get_type());
430
431 if (existing_ty.is_none()) {
432 // If not, try to compile it.
433 py::bool_ can_compile = py::module::import("torch._jit_internal")
434 .attr("can_compile_class")(input.get_type());
435
436 if (py::cast<bool>(can_compile)) {
437 // Try to compile the class. This is wrapped in a try-catch because
438 // compilation of class types can raise an Exception and in that case,
439 // we want to defer to other attempts at type inference below rather
440 // than fail compilation altogether.
441 try {
442 py::module::import("torch.jit._script")
443 .attr("_recursive_compile_class")(
444 input.get_type(), SourceRange());
445 } catch (...) {
446 // Invalidate the assumption that the class compiled so that we don't
447 // look up and return its JIT type as the type for the input.
448 class_compiled = false;
449 }
450 }
451 }
452
453 // If the class compiled successfully, look up the existing JIT type by
454 // qualified name and return it.
455 if (class_compiled) {
456 auto script_class = py::module::import("torch.jit._state")
457 .attr("_get_script_class")(input.get_type());
458
459 if (!script_class.is_none()) {
460 auto class_type = py::cast<ClassTypePtr>(script_class);
461
462 if (class_type && !class_type->is_module()) {
463 return InferredType(std::move(class_type));
464 }
465 }
466 }
467 }
468
469 if (py::isinstance<Object>(input)) {
470 auto object = py::cast<Object>(input);
471 return InferredType(object.type());
472 #ifdef USE_RPC
473 } else if (py::isinstance<torch::distributed::rpc::PyRRef>(input)) {
474 auto rref_ivalue = input.cast<torch::distributed::rpc::PyRRef>().toIValue();
475 return InferredType(rref_ivalue.type());
476 #endif
477 }
478
479 auto await_type = py::module::import("torch._awaits").attr("_Await");
480 py::bool_ is_await = py::isinstance(input, await_type);
481 if (py::cast<bool>(is_await)) {
482 auto awptr = input.cast<std::shared_ptr<PythonAwaitWrapper>>();
483 return InferredType(AwaitType::create(awptr->aw_->elementType()));
484 }
485
486 if (as_module(py::cast<py::object>(input))) {
487 return InferredType("Cannot infer type of ScriptModule");
488 }
489
490 auto module_type = py::module::import("torch.nn").attr("Module");
491 py::bool_ is_module = py::isinstance(input, module_type);
492 if (py::cast<bool>(is_module)) {
493 return InferredType("Cannot infer concrete type of torch.nn.Module");
494 }
495
496 // Try container types
497 return tryToInferContainerType(input, false);
498 }
499
500 // This function is similar to tryToInferType, but it only tries to infer
501 // primitive types (int, float, bool, complex) or nested container of primitive
502 // types.
tryToInferPrimitiveType(py::handle input)503 inline InferredType tryToInferPrimitiveType(py::handle input) {
504 if (input.is_none()) {
505 return InferredType(NoneType::get());
506 }
507
508 // Only primitive data type
509 if (py::isinstance<py::bool_>(input)) {
510 return InferredType(BoolType::get());
511 // NOLINTNEXTLINE(bugprone-branch-clone)
512 } else if (py::isinstance<py::int_>(input)) {
513 return InferredType(IntType::get());
514 } else if (py::isinstance<py::float_>(input)) {
515 return InferredType(FloatType::get());
516 } else if (PyComplex_CheckExact(input.ptr())) {
517 return InferredType(ComplexType::get());
518 }
519
520 // Try container types
521 return tryToInferContainerType(input, true);
522 }
523
524 inline InferredType tryToInferContainerType(
525 py::handle input,
526 bool primitiveTypeOnly = false) {
527 if (six::isTuple(input)) {
528 py::tuple tuple = py::cast<py::tuple>(input);
529 std::vector<TypePtr> element_types;
530 element_types.reserve(tuple.size());
531
532 for (py::handle elem : tuple) {
533 auto type_match = primitiveTypeOnly ? tryToInferPrimitiveType(elem)
534 : tryToInferType(elem);
535 if (type_match.success()) {
536 element_types.push_back(type_match.type());
537 } else {
538 // Forward error message along
539 return type_match.reason();
540 }
541 }
542 return InferredType(TupleType::create(std::move(element_types)));
543 } else if (PyDict_Check(input.ptr())) {
544 // Check to make sure we can generate useful input/output types
545 auto dict = py::cast<py::dict>(input);
546 size_t len = py::len(dict);
547 if (!len) {
548 return InferredType("Dictionary inputs must have entries");
549 }
550
551 TypePtr key_type = nullptr;
552 TypePtr value_type = nullptr;
553
554 for (auto entry : dict) {
555 // Try to infer the key type and unify it with the existing one
556 auto entry_key_type_match = primitiveTypeOnly
557 ? tryToInferPrimitiveType(entry.first)
558 : tryToInferType(entry.first);
559 if (!entry_key_type_match.success()) {
560 return entry_key_type_match.reason();
561 }
562 auto unified_key =
563 unifyOrInitializeType(key_type, entry_key_type_match.type());
564 if (!unified_key) {
565 return InferredType(c10::str(
566 "Dictionary inputs to traced functions must have consistent type. Found ",
567 key_type->repr_str(),
568 " and ",
569 (entry_key_type_match.type())->repr_str()));
570 }
571
572 // Try to infer the value type and unify it with the existing one
573 auto entry_value_type_match = primitiveTypeOnly
574 ? tryToInferPrimitiveType(entry.second)
575 : tryToInferType(entry.second);
576 if (!entry_value_type_match.success()) {
577 return entry_value_type_match.reason();
578 }
579 auto unified_value =
580 unifyOrInitializeType(value_type, entry_value_type_match.type());
581 if (!unified_value) {
582 return InferredType(c10::str(
583 "Dictionary inputs to traced functions must have consistent type. Found ",
584 value_type->repr_str(),
585 " and ",
586 (entry_value_type_match.type())->repr_str()));
587 }
588
589 key_type = *unified_key;
590 value_type = *unified_value;
591 }
592 return InferredType(
593 DictType::create(std::move(key_type), std::move(value_type)));
594 } else if (PyList_Check(input.ptr())) {
595 auto list = py::cast<py::list>(input);
596 size_t len = py::len(list);
597 if (!len) {
598 return InferredType("List trace inputs must have elements");
599 }
600
601 TypePtr element_type = nullptr;
602 for (auto elem : list) {
603 auto element_type_match = primitiveTypeOnly
604 ? tryToInferPrimitiveType(elem)
605 : tryToInferType(elem);
606 if (!element_type_match.success()) {
607 return InferredType(c10::str(
608 "Could not infer type of list element: ",
609 element_type_match.reason()));
610 }
611 auto unified_type =
612 unifyOrInitializeType(element_type, element_type_match.type());
613 if (!unified_type) {
614 return InferredType(c10::str(
615 "List inputs to traced functions must have consistent element type. Found ",
616 element_type->repr_str(),
617 " and ",
618 (element_type_match.type())->repr_str()));
619 }
620 element_type = *unified_type;
621 }
622 return InferredType(ListType::create(element_type));
623 } else {
624 if (primitiveTypeOnly) {
625 return InferredType(c10::str(
626 "Only tuple, list, or dict (possibly nested) of primitive types (bool, float, int, complex)",
627 "are supported ",
628 "as inputs or outputs of traced functions",
629 ", but instead got value of type ",
630 py::str(input.get_type().attr("__name__")),
631 "."));
632 } else {
633 // TODO: this message is not correct anymore, since this InferredType is
634 // used from a bunch of circumstances unrelated to tracing. We can re-use
635 // this instead of the attribute_failure stuff in concreteType
636 return InferredType(c10::str(
637 "Only tensors and (possibly nested) tuples of tensors, lists, or dicts",
638 "are supported ",
639 "as inputs or outputs of traced functions",
640 ", but instead got value of type ",
641 py::str(input.get_type().attr("__name__")),
642 "."));
643 }
644 }
645 }
646
isTraceableType(const TypePtr & type)647 inline bool isTraceableType(const TypePtr& type) {
648 if (type->isSubtypeOf(*TensorType::get())) {
649 return true;
650 }
651
652 if (auto list_type = type->cast<ListType>()) {
653 return isTraceableType(list_type->getElementType());
654 }
655
656 if (auto tuple_type = type->cast<TupleType>()) {
657 return std::all_of(
658 tuple_type->elements().begin(),
659 tuple_type->elements().end(),
660 [](const TypePtr& element_type) {
661 return isTraceableType(element_type);
662 });
663 }
664
665 if (auto dict_type = type->cast<DictType>()) {
666 return isTraceableType(dict_type->getValueType());
667 }
668
669 return false;
670 }
671
toTypeInferredIValue(py::handle input)672 inline IValue toTypeInferredIValue(py::handle input) {
673 auto match = tryToInferType(input);
674 if (!match.success()) {
675 auto object = py::cast<py::object>(input);
676 if (auto mod = as_module(object)) {
677 // if obj is already a ScriptModule, just return its ivalue
678 auto ptr = mod.value()._ivalue();
679 // explict copy semantics for strong ownership of the resource.
680 return c10::intrusive_ptr<c10::ivalue::Object>::reclaim_copy(
681 ptr.release());
682 }
683
684 // Check if the obj is a ScriptObject.
685 if (auto script_obj = as_object(object)) {
686 auto ptr = script_obj.value()._ivalue();
687 return c10::intrusive_ptr<c10::ivalue::Object>::reclaim_copy(
688 ptr.release());
689 }
690 AT_ERROR(
691 "Tracer cannot infer type of ", py::str(input), "\n:", match.reason());
692 }
693 return toIValue(input, match.type());
694 }
695
toTraceableStack(const py::tuple & inputs)696 inline Stack toTraceableStack(const py::tuple& inputs) {
697 auto info = toTypeInferredIValue(inputs);
698 TORCH_CHECK(
699 isTraceableType(info.type()),
700 "Type '",
701 info.type()->repr_str(),
702 "' cannot be traced. Only Tensors and (possibly nested) Lists, Dicts, and"
703 " Tuples of Tensors can be traced");
704 return info.toTupleRef().elements().vec();
705 }
706
707 // Serialize the python dictionary into a traceable stack.
toTraceableStack(const py::dict & inputs)708 inline Stack toTraceableStack(const py::dict& inputs) {
709 Stack res;
710 for (auto it = inputs.begin(); it != inputs.end(); it++) {
711 if (THPVariable_Check(it->second.ptr())) {
712 res.push_back(toIValue(it->second, tryToInferType(it->second).type()));
713 }
714 }
715 return res;
716 }
717
createGenericList(py::handle obj,const TypePtr & elem_type)718 inline IValue createGenericList(py::handle obj, const TypePtr& elem_type) {
719 auto elems = c10::impl::GenericList(elem_type);
720 for (auto elem : obj) {
721 elems.push_back(toIValue(elem, elem_type));
722 }
723 return IValue(elems);
724 }
725
createGenericDict(const py::dict & obj,const TypePtr & key_type,const TypePtr & value_type)726 inline IValue createGenericDict(
727 const py::dict& obj,
728 const TypePtr& key_type,
729 const TypePtr& value_type) {
730 c10::impl::GenericDict elems(key_type, value_type);
731 elems.reserve(py::len(obj));
732 for (auto& entry : obj) {
733 elems.insert(
734 toIValue(entry.first, key_type), toIValue(entry.second, value_type));
735 }
736 return IValue(elems);
737 }
738
739 template <class T>
guardAgainstNamedTensor(const T & var)740 inline void guardAgainstNamedTensor(const T& var) {
741 TORCH_CHECK(
742 !var.has_names(),
743 "NYI: Named tensors are currently unsupported in TorchScript. As a "
744 "workaround please drop names via `tensor = tensor.rename(None)`.");
745 }
746
747 // Extract custom class registered with torchbind
748 template <typename T>
toCustomClass(py::handle obj)749 c10::intrusive_ptr<T> toCustomClass(py::handle obj) {
750 static_assert(
751 std::is_base_of_v<CustomClassHolder, T>, "T is not a CustomClass");
752 const auto& type = c10::getCustomClassType<c10::intrusive_ptr<T>>();
753 c10::IValue ivalue = toIValue(obj, type);
754 return std::move(ivalue).toCustomClass<T>();
755 }
756
757 // Small wrapper around getting the type name string from Python to make
758 // types easier to interpret, e.g. give the structural type for a NamedTuple
friendlyTypeName(py::handle obj)759 inline std::string friendlyTypeName(py::handle obj) {
760 if (py::isinstance<py::tuple>(obj) && py::hasattr(obj, "_fields")) {
761 auto field_names =
762 py::cast<std::vector<std::string>>(py::getattr(obj, "_fields"));
763 std::stringstream ss;
764 ss << py::str(obj.get_type().attr("__name__"));
765 ss << " (aka NamedTuple(";
766 bool first = true;
767 for (auto& field_name : field_names) {
768 if (!first) {
769 ss << ", ";
770 }
771 ss << field_name;
772 first = false;
773 }
774 ss << "))";
775 return ss.str();
776 } else {
777 return py::str(obj.get_type().attr("__name__"));
778 }
779 }
780
// Thrown when trying to create a schema for a list of python
// arguments that cannot be converted.
// Can be caught by the caller to attempt to use other schema
// when there is an overloaded operator.
struct schema_match_error : std::runtime_error {
 private:
  using Base = std::runtime_error;

 public:
  // Inherit std::runtime_error's (explicit) message constructors.
  using Base::Base;
};
788
argumentToIValue(const FunctionSchema & schema,size_t argumentPosition,py::handle object)789 inline IValue argumentToIValue(
790 const FunctionSchema& schema,
791 size_t argumentPosition,
792 py::handle object) {
793 const auto& argument = schema.arguments().at(argumentPosition);
794 try {
795 return toIValue(object, argument.real_type(), argument.N());
796 } catch (const py::cast_error& error) {
797 throw schema_match_error(c10::str(
798 schema.formatTypeMismatchMsg(
799 argument,
800 friendlyTypeName(object),
801 argumentPosition,
802 py::repr(object)),
803 "\nCast error details: ",
804 error.what()));
805 } catch (const py::error_already_set& error) {
806 throw schema_match_error(c10::str(
807 schema.formatTypeMismatchMsg(
808 argument,
809 friendlyTypeName(object),
810 argumentPosition,
811 py::repr(object)),
812 "\n Python error details: ",
813 error.what()));
814 }
815 }
816
returnToIValue(const TypePtr & type,py::handle object)817 inline IValue returnToIValue(const TypePtr& type, py::handle object) {
818 try {
819 return toIValue(object, type);
820 } catch (const py::cast_error& error) {
821 throw std::runtime_error(c10::str(
822 " expected value of type ",
823 type->str(),
824 " for return value but instead got value of type ",
825 py::str(object.get_type().attr("__name__")),
826 ".",
827 "\nValue: ",
828 py::repr(object),
829 "\nCast error details: ",
830 error.what()));
831 }
832 }
833
getScriptedClassOrError(const c10::NamedTypePtr & classType)834 inline py::object getScriptedClassOrError(const c10::NamedTypePtr& classType) {
835 auto py_class =
836 py::module::import("torch.jit._state")
837 .attr("_get_python_class")(classType->name()->qualifiedName());
838 if (py_class.is_none()) {
839 std::stringstream err;
840 err << "Unknown reference to ScriptClass ";
841 err << classType->name()->qualifiedName();
842 err << ". (Did you forget to import it?)";
843 throw std::runtime_error(err.str());
844 }
845 return py_class;
846 }
847
848 struct VISIBILITY_HIDDEN tuple_slice {
tuple_slicetuple_slice849 /*implicit*/ tuple_slice(py::tuple tup_)
850 : tup(std::move(tup_)), b(0), e(tup.size()) {}
tuple_slicetuple_slice851 tuple_slice(py::tuple tup_, int64_t b_)
852 : tup(std::move(tup_)), b(b_), e(tup.size()) {}
tuple_slicetuple_slice853 tuple_slice(py::tuple tup_, int64_t b_, int64_t e_)
854 : tup(std::move(tup_)), b(b_), e(e_) {}
begintuple_slice855 py::detail::tuple_iterator begin() const {
856 return {tup, static_cast<pybind11::ssize_t>(b)};
857 }
endtuple_slice858 py::detail::tuple_iterator end() const {
859 return {tup, static_cast<pybind11::ssize_t>(e)};
860 }
sizetuple_slice861 size_t size() const {
862 return e - b;
863 }
864 py::detail::tuple_accessor operator[](size_t index) const {
865 return {tup, static_cast<size_t>(b + index)};
866 }
867
868 private:
869 py::tuple tup;
870 int64_t b;
871 int64_t e;
872 };
873
validateFakeScriptObjectSchema(const c10::FunctionSchema & schema,size_t argumentPosition,py::handle object)874 inline bool validateFakeScriptObjectSchema(
875 const c10::FunctionSchema& schema,
876 size_t argumentPosition,
877 py::handle object) {
878 auto argument = schema.arguments().at(argumentPosition);
879 auto class_type = argument.real_type()->expect<c10::ClassType>();
880 auto fake_class_registry =
881 py::module::import("torch._library.fake_class_registry");
882 auto fake_class = fake_class_registry.attr("find_fake_class")(
883 class_type->name().value().qualifiedName());
884 if (!py::isinstance(object.attr("wrapped_obj"), fake_class)) {
885 throw schema_match_error(c10::str(
886 schema.formatTypeMismatchMsg(
887 argument,
888 friendlyTypeName(object),
889 argumentPosition,
890 py::repr(object.attr("wrapped_obj"))),
891 "\nCast error details: ",
892 argument.name(),
893 " is expected to be a FakeScriptObject of ",
894 class_type->name().value().qualifiedName()));
895 }
896 return true;
897 }
898
matchSchemaAllowFakeScriptObject(const FunctionSchema & schema,const tuple_slice & args,const py::kwargs & kwargs)899 inline bool matchSchemaAllowFakeScriptObject(
900 const FunctionSchema& schema,
901 const tuple_slice& args,
902 const py::kwargs& kwargs) {
903 size_t all_arguments = args.size() + kwargs.size();
904 if (all_arguments > schema.arguments().size()) {
905 throw schema_match_error(c10::str(
906 schema.name(),
907 "() expected at most ",
908 schema.arguments().size(),
909 " argument(s) but received ",
910 all_arguments,
911 " argument(s). Declaration: ",
912 schema));
913 }
914
915 int64_t arg_idx = 0;
916 auto fake_class_registry =
917 py::module::import("torch._library.fake_class_registry");
918
919 // First push all positional args.
920 for (const auto& arg : args) {
921 // ...but refuse to do it if the schema says that this was supposed
922 // to be keyword only
923 if (schema.arguments()[arg_idx].kwarg_only()) {
924 throw schema_match_error(c10::str(
925 schema.name(),
926 "() takes ",
927 arg_idx,
928 " positional argument(s) but ",
929 args.size(),
930 " was/were given. Declaration: ",
931 schema));
932 }
933 // Use the type information from the schema to convert the PyObject.
934 const auto& argument = schema.arguments().at(arg_idx);
935 if (argument.real_type()->kind() == TypeKind::ClassType &&
936 py::isinstance(arg, fake_class_registry.attr("FakeScriptObject"))) {
937 validateFakeScriptObjectSchema(schema, arg_idx, arg);
938 } else {
939 argumentToIValue(schema, arg_idx, arg);
940 }
941
942 arg_idx++;
943 }
944
945 // Now for every remaining non-positional argument in the schema, look for it
946 // in the kwargs dict and push it if found, or use its default value if it
947 // has one.
948 size_t consumed_kwargs = 0;
949 for (size_t i = arg_idx; i < schema.arguments().size(); ++i) {
950 const auto& arg = schema.arguments()[i];
951 if (kwargs.contains(arg.name().c_str())) {
952 auto cur_kwarg = kwargs[arg.name().c_str()];
953 if (arg.real_type()->kind() == TypeKind::ClassType &&
954 py::isinstance(
955 cur_kwarg, fake_class_registry.attr("FakeScriptObject"))) {
956 validateFakeScriptObjectSchema(schema, i, cur_kwarg);
957 } else {
958 argumentToIValue(schema, i, cur_kwarg);
959 }
960 consumed_kwargs += 1;
961 } else if (arg.default_value()) {
962 continue;
963 } else {
964 throw schema_match_error(c10::str(
965 schema.name(),
966 "() is missing value for argument '",
967 arg.name(),
968 "'. Declaration: ",
969 schema));
970 }
971 }
972
973 if (consumed_kwargs != kwargs.size()) {
974 std::vector<std::string> names;
975 for (const auto& kwarg : kwargs) {
976 names.emplace_back(py::cast<std::string>(kwarg.first));
977 }
978 throw schema_match_error(schema.findErrorInKwargs(names));
979 }
980
981 return true;
982 }
983
createStackForSchema(const FunctionSchema & schema,const tuple_slice & args,const py::kwargs & kwargs,std::optional<IValue> self)984 inline Stack createStackForSchema(
985 const FunctionSchema& schema,
986 const tuple_slice& args,
987 const py::kwargs& kwargs,
988 std::optional<IValue> self) {
989 size_t all_arguments = (self ? 1 : 0) + args.size() + kwargs.size();
990 if (all_arguments > schema.arguments().size()) {
991 throw schema_match_error(c10::str(
992 schema.name(),
993 "() expected at most ",
994 schema.arguments().size(),
995 " argument(s) but received ",
996 all_arguments,
997 " argument(s). Declaration: ",
998 schema));
999 }
1000 Stack stack;
1001 stack.reserve(schema.arguments().size());
1002
1003 int64_t arg_idx = 0;
1004 if (self) {
1005 push(stack, std::move(*self));
1006 arg_idx++;
1007 }
1008 // First push all positional args.
1009 for (const auto& arg : args) {
1010 // ...but refuse to do it if the schema says that this was supposed
1011 // to be keyword only
1012 if (schema.arguments()[arg_idx].kwarg_only()) {
1013 throw schema_match_error(c10::str(
1014 schema.name(),
1015 "() takes ",
1016 arg_idx,
1017 " positional argument(s) but ",
1018 self ? 1 + args.size() : args.size(),
1019 " was/were given. Declaration: ",
1020 schema));
1021 }
1022 // Use the type information from the schema to convert the PyObject.
1023 push(stack, argumentToIValue(schema, stack.size(), arg));
1024 arg_idx++;
1025 }
1026
1027 // Now for every remaining non-positional argument in the schema, look for it
1028 // in the kwargs dict and push it if found, or use its default value if it
1029 // has one.
1030 size_t consumed_kwargs = 0;
1031 for (size_t i = stack.size(); i < schema.arguments().size(); ++i) {
1032 const auto& arg = schema.arguments()[i];
1033 if (kwargs.contains(arg.name().c_str())) {
1034 push(stack, argumentToIValue(schema, i, kwargs[arg.name().c_str()]));
1035 consumed_kwargs += 1;
1036 } else if (arg.default_value()) {
1037 push(stack, *arg.default_value());
1038 } else {
1039 throw schema_match_error(c10::str(
1040 schema.name(),
1041 "() is missing value for argument '",
1042 arg.name(),
1043 "'. Declaration: ",
1044 schema));
1045 }
1046 }
1047
1048 if (consumed_kwargs != kwargs.size()) {
1049 std::vector<std::string> names;
1050 for (const auto& kwarg : kwargs) {
1051 names.emplace_back(py::cast<std::string>(kwarg.first));
1052 }
1053 throw schema_match_error(schema.findErrorInKwargs(names));
1054 }
1055
1056 return stack;
1057 }
1058
createPyObjectForStack(Stack && stack)1059 inline py::object createPyObjectForStack(Stack&& stack) {
1060 if (stack.empty()) {
1061 return py::none();
1062 }
1063
1064 // Return a simple value and not a single-element tuple if there is only one
1065 // return value.
1066 if (stack.size() == 1) {
1067 return toPyObject(std::move(stack[0]));
1068 }
1069
1070 // If there is more than one return value, pop them into a py::tuple.
1071 py::tuple return_values(stack.size());
1072 for (const auto ret : c10::irange(return_values.size())) {
1073 return_values[ret] = toPyObject(std::move(stack[ret]));
1074 }
1075
1076 return std::move(return_values);
1077 }
1078
1079 // TODO: Remove once we clean up the GraphExecutor usage.
1080 inline Stack evilDeprecatedBadCreateStackDoNotUse(
1081 const py::tuple& tuple,
1082 at::ArrayRef<Value*> inputs,
1083 size_t reserve_extra_space = 0) {
1084 if (tuple.size() != inputs.size()) {
1085 AT_ERROR(
1086 "expected " + std::to_string(inputs.size()) + " inputs, but got " +
1087 std::to_string(tuple.size()));
1088 }
1089 Stack result;
1090 result.reserve(tuple.size() + reserve_extra_space);
1091 for (const auto i : c10::irange(inputs.size())) {
1092 result.push_back(toIValue(std::move(tuple[i]), inputs[i]->type()));
1093 }
1094 return result;
1095 }
1096
1097 // Run `callee`, potentially inserting a CallFunction/CallMethod node into the
1098 // tracing graph.
runAndInsertCall(Function & callee,const tuple_slice & args,const py::kwargs & kwargs,std::optional<IValue> self,const std::function<Value * (Graph &,const MatchedSchema & match)> & callInserter)1099 inline py::object runAndInsertCall(
1100 Function& callee,
1101 const tuple_slice& args,
1102 const py::kwargs& kwargs,
1103 std::optional<IValue> self,
1104 // Lambda that tells this function how to insert `callee` into the graph if
1105 // we're tracing.
1106 const std::function<Value*(Graph&, const MatchedSchema& match)>&
1107 callInserter) {
1108 auto stack =
1109 createStackForSchema(callee.getSchema(), args, kwargs, std::move(self));
1110 const auto& tracing_state = tracer::getTracingState();
1111 if (!tracing_state) {
1112 pybind11::gil_scoped_release no_gil_guard;
1113 // If we're not tracing, just run the callee as normal.
1114 callee.run(stack);
1115 } else {
1116 // If we are tracing, insert the appropriate CallFunction or CallMethod node
1117 // and then run the callee with tracing disabled.
1118
1119 // Get the graph `Value`s that represent the input IValues
1120 auto inputs = last(stack, callee.num_inputs());
1121 auto input_values =
1122 fmap(inputs, [](const IValue& v) { return tracer::getValueTrace(v); });
1123 TORCH_INTERNAL_ASSERT(callee.getSchema().returns().size() == 1)
1124 auto return_type = callee.getSchema().returns().at(0).type();
1125 auto graph = tracing_state->graph;
1126 std::vector<NamedValue> named_values;
1127 named_values.reserve(input_values.size());
1128 for (Value* v : input_values) {
1129 named_values.emplace_back(v);
1130 }
1131
1132 // Add a call node.
1133 MatchedSchema match = matchSchema(
1134 callee.getSchema(),
1135 tracer::getPythonInterpreterSourceRange(),
1136 *graph,
1137 named_values,
1138 {});
1139 auto output_value = callInserter(*graph, match);
1140
1141 // Actually run the callee. Pause the tracer so that we don't double-add the
1142 // callee nodes.
1143 {
1144 pybind11::gil_scoped_release no_gil_guard;
1145 ResourceGuard guard(tracer::pauseTracing());
1146 callee.run(stack);
1147 }
1148
1149 // Associate the output IValues with the output `Value`s in the graph
1150 tracer::setValueTrace(stack.back(), output_value);
1151 }
1152
1153 TORCH_CHECK(
1154 !stack.empty(),
1155 "Expected values in the stack after execution but found none");
1156 return toPyObject(std::move(stack.back()));
1157 }
1158
maybeTorchFunctionDispatch(const py::object & callee,const tuple_slice & args_no_self,const py::kwargs & kwargs,const c10::QualifiedName & qualname)1159 inline std::optional<py::object> maybeTorchFunctionDispatch(
1160 const py::object& callee,
1161 const tuple_slice& args_no_self,
1162 const py::kwargs& kwargs,
1163 const c10::QualifiedName& qualname) {
1164 std::vector<py::handle> args_vec;
1165 for (const auto& arg : args_no_self) {
1166 args_vec.push_back(arg);
1167 }
1168 py::tuple args = py::cast(args_vec);
1169
1170 // Handle __torch_function__ dispatch
1171 std::vector<PyObject*> overloaded_args;
1172 size_t total_arg_num = args.size() + kwargs.size();
1173 for (const auto& arg : args) {
1174 is_tensor_and_append_overloaded(arg.ptr(), &overloaded_args);
1175 is_tensor_list_and_append_overloaded(
1176 arg.ptr(),
1177 &overloaded_args,
1178 static_cast<int>(total_arg_num),
1179 false /* throw_error */);
1180 }
1181 // NB: for kwargs, we cannot guarantee the order of appending
1182 // is the same as the argument order in operator's schema.
1183 // This is suboptimal, but should be fine. Later when we have
1184 // better schema matching and argument parsing, we could
1185 // match the operator in `operations` first, then the order will
1186 // be guaranteed.
1187 for (auto item : kwargs) {
1188 is_tensor_and_append_overloaded(item.second.ptr(), &overloaded_args);
1189 is_tensor_list_and_append_overloaded(
1190 item.second.ptr(),
1191 &overloaded_args,
1192 total_arg_num,
1193 false /* throw_error */);
1194 }
1195 if (!overloaded_args.empty()) {
1196 return pybind11::reinterpret_steal<py::object>(
1197 handle_torch_function_no_python_arg_parser(
1198 /*overloaded_args=*/overloaded_args,
1199 /*args=*/args.ptr(),
1200 /*kwargs=*/kwargs.ptr(),
1201 /*func_name=*/qualname.name().c_str(),
1202 /*torch_api_function=*/callee.ptr(),
1203 /*module_name=*/qualname.prefix().c_str()));
1204 }
1205
1206 return std::nullopt;
1207 }
1208
invokeScriptFunctionFromPython(Function & callee,const tuple_slice & args,const py::kwargs & kwargs)1209 inline py::object invokeScriptFunctionFromPython(
1210 Function& callee,
1211 const tuple_slice& args,
1212 const py::kwargs& kwargs) {
1213 // TODO: we could add __torch_function__ dispatch here but I don't know
1214 // the implications of doing so
1215
1216 return runAndInsertCall(
1217 callee,
1218 args,
1219 kwargs,
1220 /*self=*/std::nullopt,
1221 [&](Graph& graph, const MatchedSchema& match) {
1222 return graph.insertFunctionCall(&callee, match);
1223 });
1224 }
1225
invokeScriptMethodFromPython(Method & callee,const tuple_slice & args,const py::kwargs & kwargs)1226 inline py::object invokeScriptMethodFromPython(
1227 Method& callee,
1228 const tuple_slice& args,
1229 const py::kwargs& kwargs) {
1230 auto self = callee.owner()._ivalue();
1231
1232 if (auto torch_fn_result = maybeTorchFunctionDispatch(
1233 py::cast(callee), args, kwargs, callee.name())) {
1234 return *torch_fn_result;
1235 }
1236
1237 return runAndInsertCall(
1238 callee.function(),
1239 args,
1240 kwargs,
1241 self,
1242 [&](Graph& graph, const MatchedSchema& match) {
1243 return graph.insertMethodCall(callee.name(), match);
1244 });
1245 }
1246
// NOTE(review): the functions below are defined out of line; the summaries
// are inferred from their signatures and should be confirmed against the
// definitions.

// Presumably selects the entry of `operations` whose schema accepts the given
// Python args/kwargs and returns it together with the argument Stack built
// for it.
TORCH_PYTHON_API std::pair<std::shared_ptr<Operator>, Stack> getOpWithStack(
    const std::vector<std::shared_ptr<Operator>>& operations,
    const py::args& args,
    const py::kwargs& kwargs);

// Presumably invokes a matching entry of `operations` with the given Python
// args/kwargs and returns the result as a Python object. `dk` is an optional
// dispatch key — TODO confirm how it is used.
TORCH_PYTHON_API py::object invokeOperatorFromPython(
    const std::vector<std::shared_ptr<Operator>>& operations,
    const py::args& args,
    const py::kwargs& kwargs,
    std::optional<c10::DispatchKey> dk = std::nullopt);

// Judging by the name and optional return type: returns a result when
// __torch_function__ handling applies to this op call, std::nullopt
// otherwise — verify.
TORCH_PYTHON_API std::optional<py::object> _maybe_handle_torch_function(
    const std::string& ns,
    const std::string& method_name,
    const std::string& overload_name,
    bool is_overload,
    const py::args& args,
    const py::kwargs& kwargs);

// Presumably the py::args-based counterpart of
// matchSchemaAllowFakeScriptObject — confirm.
TORCH_PYTHON_API bool checkSchemaAllowFakeScriptObject(
    const FunctionSchema& schema,
    const py::args& args,
    const py::kwargs& kwargs);

// Presumably resolves and invokes the operation for `symbol` (either a single
// overload or an overload packet, per `is_overload`) with the given Python
// args/kwargs — confirm against the definition.
TORCH_PYTHON_API py::object _get_operation_for_overload_or_packet(
    const std::vector<std::shared_ptr<Operator>>& operations,
    Symbol symbol,
    const py::args& args,
    const py::kwargs& kwargs,
    bool is_overload,
    std::optional<c10::DispatchKey> dk = std::nullopt);
1278
1279 } // namespace torch::jit
1280