#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <torch/csrc/utils/python_dispatch.h>

#include <ATen/ATen.h>
#include <ATen/FuncTorchTLS.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/core/NestedIntSymNodeImpl.h>
#include <ATen/core/PythonOpRegistrationTrampoline.h>
#include <ATen/core/dispatch/Dispatcher.h>

#include <ATen/functorch/BatchedTensorImpl.h>
#include <torch/library.h>

#include <c10/core/SafePyObject.h>
#include <torch/csrc/PyInterpreter.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/tensor_new.h>

#include <c10/util/flat_hash_map.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <torch/csrc/inductor/aoti_eager/kernel_holder.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_raii.h>

#include <iostream>
#include <utility>

namespace py = pybind11;

namespace torch::impl::dispatch {

// NB: I'd like to index this on OperatorHandle, but I can't, as I can't
// guarantee that the main interpreter has finished doing all registrations
// before the other interpreters start banging on it
static ska::flat_hash_map<
    c10::OperatorName,
    ska::flat_hash_map<c10::DispatchKey, std::shared_ptr<c10::SafePyObject>>>
    python_registrations_;
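// Illustrative shape of the table (the operator name is hypothetical): a
// kernel registered from Python for "my_ns::my_op.overload" on CPU would
// live at
//   python_registrations_[{"my_ns::my_op", "overload"}]
//                        [c10::DispatchKey::CPU]
// as a SafePyObject wrapping the Python callable; see the "impl" binding and
// python_op_registration_trampoline_impl below.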

static torch::Library::Kind parseKind(const std::string& k) {
  static std::unordered_map<std::string, torch::Library::Kind> kind_map = {
      {"DEF", torch::Library::DEF},
      {"IMPL", torch::Library::IMPL},
      {"FRAGMENT", torch::Library::FRAGMENT},
  };
  auto it = kind_map.find(k);
  TORCH_CHECK(it != kind_map.end(), "could not parse ", k);
  return it->second;
}
static c10::AliasAnalysisKind parseAliasAnalysisKind(const std::string& k) {
  static std::unordered_map<std::string, c10::AliasAnalysisKind> key_map = {
      {"CONSERVATIVE", c10::AliasAnalysisKind::CONSERVATIVE},
      {"FROM_SCHEMA", c10::AliasAnalysisKind::FROM_SCHEMA},
      {"PURE_FUNCTION", c10::AliasAnalysisKind::PURE_FUNCTION},
      {"", c10::AliasAnalysisKind::FROM_SCHEMA}, // default
  };
  auto it = key_map.find(k);
  TORCH_CHECK(it != key_map.end(), "could not parse ", k);
  return it->second;
}

template <typename Func>
inline torch::CppFunction dispatch_str(const char* key, Func&& raw_f) {
  if (key[0] != '\0') {
    return torch::dispatch(
        c10::parseDispatchKey(key), std::forward<Func>(raw_f));
  } else {
    torch::CppFunction f(std::forward<Func>(raw_f));
    return f;
  }
}
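// For example (illustrative): dispatch_str("CPU", fn) pins fn to
// c10::DispatchKey::CPU, while dispatch_str("", fn) produces a catch-all
// CppFunction with no dispatch key. Any key string that
// c10::parseDispatchKey does not recognize will throw.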

struct EnableHermeticPyObject {
  EnableHermeticPyObject()
      : old_(c10::impl::HermeticPyObjectTLS::get_state()),
        old_excluded_python_(
            c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Python)),
        old_python_(
            c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Python)),
        old_python_snapshot_(c10::impl::tls_is_dispatch_key_included(
            at::DispatchKey::PythonTLSSnapshot)) {
    c10::impl::HermeticPyObjectTLS::set_state(true);
    c10::impl::tls_set_dispatch_key_excluded(at::DispatchKey::Python, true);
    c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Python, false);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::PythonTLSSnapshot, false);
  }
  ~EnableHermeticPyObject() {
    c10::impl::HermeticPyObjectTLS::set_state(old_);
    c10::impl::tls_set_dispatch_key_excluded(
        at::DispatchKey::Python, old_excluded_python_);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::Python, old_python_);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::PythonTLSSnapshot, old_python_snapshot_);
  }
  bool old_;
  bool old_excluded_python_;
  bool old_python_;
  bool old_python_snapshot_;
};
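// Rough usage sketch: constructing the guard snapshots the hermetic TLS bit
// and the Python/PythonTLSSnapshot include/exclude bits, then forces
// hermetic mode on with the Python keys masked off; the destructor restores
// the snapshot.
//
//   {
//     EnableHermeticPyObject g; // hermetic on, Python keys masked
//     // ... code that must not attach PyObjects to tensors ...
//   } // prior TLS state restored
//
// It is only instantiated under USE_DEPLOY below.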

class PythonKernelHolder : public c10::OperatorKernel {
  c10::SafePyObject func_;
  c10::DispatchKey dispatch_key_;
  // If "with_keyset", then we expect a keyset as the first arg.
  bool with_keyset_;
  // If "with_op", then we expect the op as first arg (or second if keyset)
  bool with_op_;

 public:
  PythonKernelHolder(
      py::object func,
      c10::DispatchKey dispatch_key,
      bool with_keyset = false,
      bool with_op = false)
      : func_(func.release().ptr(), getPyInterpreter()),
        dispatch_key_(dispatch_key),
        with_keyset_(with_keyset),
        with_op_(with_op) {}

  void operator()(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet keyset,
      torch::jit::Stack* stack) {
    // Figure out if we can handle it hermetically, or if we have
    // to double dispatch

    // If Torch Dispatch Mode is active, use its PyInterpreter for dispatch
    const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
    if (mode_stack_len > 0) {
      const auto& cur_torch_dispatch_mode_state =
          c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
      cur_torch_dispatch_mode_state->pyinterpreter()
          ->python_op_registration_trampoline(
              op, dispatch_key_, keyset, stack, with_keyset_, with_op_);
      return;
    }

    const auto& schema = op.schema();
    const auto num_arguments = schema.arguments().size();

    // Otherwise, find a PyInterpreter on a Tensor if it has the Python key
    // (which means it's a nontrivial tensor subclass)
    for (const auto& ivalue : torch::jit::last(*stack, num_arguments)) {
      if (ivalue.isTensor()) {
        auto* interpreter =
            ivalue.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
        if (interpreter &&
            ivalue.unsafeToTensorImpl()->key_set().has(
                at::DispatchKey::Python)) {
          (*interpreter)
              ->python_op_registration_trampoline(
                  op, dispatch_key_, keyset, stack, with_keyset_, with_op_);
          return;
        }
      } else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) {
        // NB: use toListRef as it doesn't induce refcount bumps
        // (toTensorListRef is not a thing)
        for (const auto& nv : ivalue.toListRef()) {
          if (nv.isNone()) {
            continue;
          }
          auto* interpreter =
              nv.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
          if (interpreter &&
              nv.unsafeToTensorImpl()->key_set().has(at::DispatchKey::Python)) {
            (*interpreter)
                ->python_op_registration_trampoline(
                    op, dispatch_key_, keyset, stack, with_keyset_, with_op_);
            return;
          }
        }
      }
    }

    // Nothing requires the operator to be homed to a specific interpreter, so
    // run it on the current interpreter

    auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
    py::gil_scoped_acquire g;
    // Jan 2024: We're slated to get rid of multipy, so only force hermetic
    // mode when multipy (USE_DEPLOY) is actually in use; eventually delete
    // this entirely. (Note that you may break multipy anyway this way with
    // dispatcher registered functions that require hermetic to be off.)
#if defined(USE_DEPLOY)
    EnableHermeticPyObject g2;
#endif
    auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
    auto func =
        py::reinterpret_borrow<py::object>(func_.ptr(getPyInterpreter()));
    auto obj = with_op_ ? with_keyset_
            ? func(
                  keyset,
                  torch::detail::getTorchApiFunction(op),
                  *args_kwargs.first,
                  **args_kwargs.second)
            : func(
                  torch::detail::getTorchApiFunction(op),
                  *args_kwargs.first,
                  **args_kwargs.second)
        : with_keyset_ ? func(keyset, *args_kwargs.first, **args_kwargs.second)
                       : func(*args_kwargs.first, **args_kwargs.second);
    if (!obj) {
      throw python_error();
    }
    pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
  }
};
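// Resolution order in operator() above, in brief: (1) if a
// __torch_dispatch__ mode is active, trampoline into the interpreter that
// owns the top of the mode stack; (2) otherwise, if any Tensor (or tensor
// list element) argument carries the Python dispatch key, trampoline into
// that subclass's interpreter; (3) otherwise, call the stored Python
// callable directly on the current interpreter.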

static torch::_RegisterOrVerify register_or_verify() {
  if (isMainPyInterpreter()) {
    return torch::_RegisterOrVerify::REGISTER;
  } else {
    return torch::_RegisterOrVerify::VERIFY;
  }
}

static py::object ophandle_call_boxed(
    const c10::OperatorHandle& handle,
    const py::args& args,
    const py::kwargs& kwargs) {
  auto stack = torch::jit::createStackForSchema(
      handle.schema(),
      args,
      kwargs,
      /*self=*/std::nullopt);
  {
    pybind11::gil_scoped_release no_gil_guard;
    handle.callBoxed(stack);
  }
  return torch::jit::createPyObjectForStack(std::move(stack));
}
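// From Python this is bound below as torch._C._dispatch_call_boxed. A rough
// usage sketch, assuming tensors x and y:
//
//   op = torch._C._dispatch_find_schema_or_throw("aten::add", "Tensor")
//   out = torch._C._dispatch_call_boxed(op, x, y)
//
// Arguments are packed onto a boxed stack according to the schema, the GIL
// is released around the dispatcher call, and results are converted back to
// Python objects.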

// A small RAII guard that lets you explicitly *remove* a key from the TLS
// exclude set.
class SetExcludeDispatchKeyGuard {
 public:
  SetExcludeDispatchKeyGuard(at::DispatchKey k, bool set_excluded)
      : k(k), old(c10::impl::tls_is_dispatch_key_excluded(k)) {
    c10::impl::tls_set_dispatch_key_excluded(k, set_excluded);
  }
  ~SetExcludeDispatchKeyGuard() {
    c10::impl::tls_set_dispatch_key_excluded(k, old);
  }
  SetExcludeDispatchKeyGuard(const SetExcludeDispatchKeyGuard&) = delete;
  SetExcludeDispatchKeyGuard& operator=(const SetExcludeDispatchKeyGuard&) =
      delete;
  SetExcludeDispatchKeyGuard(SetExcludeDispatchKeyGuard&&) = delete;
  SetExcludeDispatchKeyGuard& operator=(SetExcludeDispatchKeyGuard&&) = delete;

 private:
  at::DispatchKey k;
  bool old;
};
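// Bound below as the torch._C._SetExcludeDispatchKeyGuard context manager.
// Illustrative use (set_excluded=False removes the key from the TLS exclude
// set for the duration of the block):
//
//   with torch._C._SetExcludeDispatchKeyGuard(
//       torch._C.DispatchKey.ADInplaceOrView, False):
//       ...  # ADInplaceOrView is guaranteed not excluded here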

void initDispatchBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  py::class_<c10::OperatorHandle>(m, "_DispatchOperatorHandle")
      .def("schema", &c10::OperatorHandle::schema)
      .def("debug", &c10::OperatorHandle::debug)
      .def(
          "redispatch_boxed",
          [](const py::object& self,
             c10::DispatchKeySet keyset,
             py::args args,
             const py::kwargs& kwargs) {
            auto& handle = self.cast<c10::OperatorHandle&>();
            auto stack = torch::jit::createStackForSchema(
                handle.schema(),
                std::move(args),
                kwargs,
                /*self=*/std::nullopt);
            {
              pybind11::gil_scoped_release no_gil_guard;
              handle.redispatchBoxed(keyset, &stack);
            }
            return torch::jit::createPyObjectForStack(std::move(stack));
          });

  m.def("_dispatch_call_boxed", &ophandle_call_boxed);
  // TODO: figure out how to do chaining
  py::class_<torch::Library>(m, "_DispatchModule")
      .def(
          "reset",
          [](const py::object& self) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().reset();
            return;
          },
          "")
      // Some of these APIs are only for testing and do not work in multipy
      // environment
      .def(
          "def_",
          [](py::object self, const char* schema, const char* alias) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(
                torch::schema(schema, parseAliasAnalysisKind(alias)));
            return self;
          },
          "",
          py::arg("schema"),
          py::arg("alias") = "")
      // Simulated "legacy" def where alias analysis kind is not set.
      // Ordinarily this can only be exercised from RegisterOperators() API
      // but I am not going to bind that here
      .def(
          "def_legacy",
          [](py::object self, const char* schema) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(torch::jit::parseSchema(schema));
            return self;
          },
          "",
          py::arg("schema"))
      // We can't conveniently turn Python functions into valid functions
      // in the dispatcher. So instead we provide a bunch of precanned
      // functions for testing purposes. You're NOT intended to actually
      // call these functions; they're just here so we can actually register
      // something
      //
      // Mangling scheme: args_rets. One character per.
      //   t = Tensor
      .def(
          "def_name_t_t",
          [](py::object self,
             const char* name,
             const char* dispatch,
             const char* debug) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(
                name, dispatch_str(dispatch, [](const at::Tensor& a) {
                        return a;
                      }).debug(debug));
            return self;
          },
          "",
          py::arg("name"),
          py::arg("dispatch") = "",
          py::arg("debug") = "default_def_name_t_t")
      .def(
          "def_schema_t_t",
          [](py::object self,
             const char* schema,
             const char* dispatch,
             const char* alias,
             const char* debug) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(
                torch::schema(schema, parseAliasAnalysisKind(alias)),
                dispatch_str(dispatch, [](const at::Tensor& a) {
                  return a;
                }).debug(debug));
            return self;
          },
          "",
          py::arg("schema"),
          py::arg("dispatch") = "",
          py::arg("alias") = "",
          py::arg("debug") = "default_def_schema_t_t")
      // TODO: maybe consider deduplicating the definitions here, it's getting
      // pretty long
      .def(
          "impl_t_t",
          [](py::object self,
             const char* name,
             const char* dispatch,
             const char* debug) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().impl(
                name, dispatch_str(dispatch, [](const at::Tensor& a) {
                        return a;
                      }).debug(debug));
            return self;
          },
          "",
          py::arg("name"),
          py::arg("dispatch") = "",
          py::arg("debug") = "impl_t_t")
      .def(
          "impl_with_aoti_compile",
          [](const py::object& self,
             const char* ns,
             const char* op_name_with_overload,
             c10::DispatchKey dispatch) {
            HANDLE_TH_ERRORS
            std::string reg_op_name =
                std::string(ns).append("::").append(op_name_with_overload);

            auto& lib = self.cast<torch::Library&>();
            lib.impl(
                reg_op_name.c_str(),
                torch::dispatch(
                    dispatch,
                    CppFunction::makeFromBoxedFunctor(
                        std::make_unique<
                            torch::inductor::AOTIPythonKernelHolder>(
                            dispatch, ns, op_name_with_overload))),
                register_or_verify());
            END_HANDLE_TH_ERRORS_PYBIND
          },
          "",
          py::arg("ns"),
          py::arg("op_name_with_overload"),
          py::arg("dispatch"))
      .def(
          "impl",
          [](const py::object& self,
             const char* name,
             // TODO: empty string no longer works
             c10::DispatchKey dispatch,
             py::object func,
             bool with_keyset) {
            HANDLE_TH_ERRORS
            auto& lib = self.cast<torch::Library&>();
            if (func.is(py::module::import("torch.library")
                            .attr("fallthrough_kernel"))) {
              lib.impl(
                  name,
                  torch::dispatch(dispatch, CppFunction::makeFallthrough()),
                  register_or_verify());
            } else {
              lib.impl(
                  name,
                  torch::dispatch(
                      dispatch,
                      CppFunction::makeFromBoxedFunctor(
                          std::make_unique<PythonKernelHolder>(
                              func, dispatch, with_keyset))),
                  register_or_verify());
              python_registrations_[lib._resolve(name)].insert_or_assign(
                  dispatch,
                  std::make_shared<c10::SafePyObject>(
                      func.release().ptr(), getPyInterpreter()));
            }
            END_HANDLE_TH_ERRORS_PYBIND
          },
          "",
          py::arg("name"),
          py::arg("dispatch"),
          py::arg("func"),
          py::arg("with_keyset") = false)
      .def(
          "define",
          [](const py::object& self,
             const char* schema,
             const char* alias_analysis,
             const std::vector<at::Tag>& tags) {
            auto parsed_schema =
                torch::schema(schema, parseAliasAnalysisKind(alias_analysis));
            self.cast<torch::Library&>().def(
                std::move(parsed_schema), tags, register_or_verify());
            // TODO: this is dumb, had to make a second copy
            return torch::schema(schema, parseAliasAnalysisKind(alias_analysis))
                .name();
          },
          "",
          py::arg("schema"),
          py::arg("alias_analysis") = "",
          py::arg("tags") = std::vector<at::Tag>())
      .def(
          "fallback_fallthrough",
          [](py::object self, const char* dispatch) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().fallback(
                dispatch_str(dispatch, CppFunction::makeFallthrough()));
            return self;
          },
          "",
          py::arg("dispatch") = "")
      .def(
          "fallback",
          [](const py::object& self,
             c10::DispatchKey dispatch,
             const py::object& func,
             bool with_keyset) {
            HANDLE_TH_ERRORS
            auto& lib = self.cast<torch::Library&>();
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            if (func.is(py::module::import("torch.library")
                            .attr("fallthrough_kernel"))) {
              lib.fallback(
                  torch::dispatch(dispatch, CppFunction::makeFallthrough()));
            } else {
              lib.fallback(torch::dispatch(
                  dispatch,
                  CppFunction::makeFromBoxedFunctor(
                      std::make_unique<PythonKernelHolder>(
                          func, dispatch, with_keyset, /*with_op*/ true))));
            }
            END_HANDLE_TH_ERRORS_PYBIND
          },
          "",
          py::arg("dispatch"),
          py::arg("func"),
          py::arg("with_keyset") = false);
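  // A hedged end-to-end sketch of the bindings above ("my_ns", "my_op", and
  // my_kernel are hypothetical; torch.library.Library is the supported
  // Python wrapper over this raw API):
  //
  //   lib = torch._C._dispatch_library("FRAGMENT", "my_ns", "")
  //   lib.define("my_op(Tensor x) -> Tensor")
  //   lib.impl("my_op", torch._C.DispatchKey.CPU, my_kernel)
  //
  // Passing torch.library.fallthrough_kernel as func registers a fallthrough
  // instead of a Python kernel.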

  m.def(
      "_dispatch_library",
      [](const char* kind,
         std::string name,
         const char* dispatch,
         const char* file,
         uint32_t linenum) {
        HANDLE_TH_ERRORS
        return std::make_unique<torch::Library>(
            parseKind(kind),
            std::move(name),
            std::string(dispatch).empty()
                ? std::nullopt
                : std::make_optional(c10::parseDispatchKey(dispatch)),
            "/dev/null", // temporary workaround
            linenum);
        END_HANDLE_TH_ERRORS_PYBIND
      },
      "",
      py::arg("kind"),
      py::arg("name"),
      py::arg("dispatch"),
      py::arg("file") = "/dev/null",
      py::arg("linenum") = 0);

  m.def(
      "_dispatch_find_schema_or_throw",
      [](const char* name, const char* overload_name) -> c10::OperatorHandle {
        return c10::Dispatcher::singleton().findSchemaOrThrow(
            name, overload_name);
      });

  m.def("_dispatch_dump", [](const char* name) -> std::string {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    if (!op) {
      return "";
    } else {
      return op->dumpState();
    }
  });

  m.def("_dispatch_dump_table", [](const char* name) -> std::string {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    if (!op) {
      return "";
    } else {
      return op->dumpComputedTable();
    }
  });

  m.def("_dispatch_check_invariants", [](const char* name) {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    if (op) {
      op->checkInvariants();
    }
  });

  m.def("_dispatch_check_all_invariants", []() {
    c10::Dispatcher::singleton().checkInvariants();
  });

  m.def("_dispatch_has_kernel", [](const char* name) -> bool {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    return static_cast<bool>(op);
  });

  m.def(
      // Returns whether or not a direct kernel registration exists
      // for this <op_name, dispatch_key> pair.
      "_dispatch_has_kernel_for_dispatch_key",
      [](const char* name, c10::DispatchKey dispatch) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasKernelForDispatchKey(dispatch);
      });

  m.def(
      // Returns whether or not the kernel for this dispatch key is a
      // fallthrough kernel
      "_dispatch_kernel_for_dispatch_key_is_fallthrough",
      [](const char* name, c10::DispatchKey dispatch) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->isKernelFallthroughKernel(dispatch);
      });

  m.def(
      "_dispatch_has_kernel_for_any_dispatch_key",
      [](const char* name, c10::DispatchKeySet ks) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasKernelForAnyDispatchKey(ks);
      });

  m.def(
      // Returns whether or not there is an entry in the runtime computed
      // dispatch table for this <op_name, dispatch_key> pair. For example, if
      // "op" has a `CompositeImplicitAutograd` kernel, then
      // _dispatch_has_computed_kernel_for_dispatch_key(op, backend) will
      // return true for all backends that are part of the alias set for
      // CompositeImplicitAutograd.
      "_dispatch_has_computed_kernel_for_dispatch_key",
      [](const char* name, const char* dispatch) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasComputedKernelForDispatchKey(
            c10::parseDispatchKey(dispatch));
      });

  m.def("_dispatch_find_dangling_impls", []() -> std::vector<std::string> {
    auto danglingImpls = c10::Dispatcher::singleton().findDanglingImpls();

    std::vector<std::string> states;
    states.reserve(danglingImpls.size());
    for (auto& danglingImpl : danglingImpls) {
      states.emplace_back(danglingImpl.dumpState());
    }

    return states;
  });

  m.def("_dispatch_get_all_op_names", []() -> std::vector<std::string> {
    auto op_names = c10::Dispatcher::singleton().getAllOpNames();

    std::vector<std::string> names;
    names.reserve(op_names.size());
    for (auto& op : op_names) {
      std::stringstream ss;
      ss << op.name;
      if (!op.overload_name.empty()) {
        ss << "." << op.overload_name;
      }
      names.emplace_back(ss.str());
    }

    return names;
  });
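  // Names produced above follow "<ns>::<name>[.<overload>]", e.g.
  // "aten::add" for the default overload and "aten::add.Tensor" for the
  // Tensor overload (example names are illustrative).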

  m.def(
      "_dispatch_tls_set_dispatch_key_excluded",
      [](c10::DispatchKey dispatch_key, bool desired_state) {
        c10::impl::tls_set_dispatch_key_excluded(dispatch_key, desired_state);
      });
  m.def(
      "_dispatch_tls_is_dispatch_key_excluded",
      [](c10::DispatchKey dispatch_key) {
        return c10::impl::tls_is_dispatch_key_excluded(dispatch_key);
      });
  m.def(
      "_dispatch_tls_set_dispatch_key_included",
      [](c10::DispatchKey dispatch_key, bool desired_state) {
        c10::impl::tls_set_dispatch_key_included(dispatch_key, desired_state);
      });
  m.def(
      "_dispatch_tls_is_dispatch_key_included",
      [](c10::DispatchKey dispatch_key) {
        return c10::impl::tls_is_dispatch_key_included(dispatch_key);
      });

  m.def("_dispatch_isTensorSubclassLike", [](const at::Tensor& tensor) {
    return at::isTensorSubclassLike(tensor);
  });

  m.def("_dispatch_key_name", [](c10::DispatchKey k) {
    return c10::toString(k);
  });
  m.def("_dispatch_key_parse", [](c10::DispatchKey k) { return k; });
  m.def("_to_functionality_key", [](c10::DispatchKey k) {
    return c10::toFunctionalityKey(k);
  });
  // E.g. given `DispatchKey::AutogradFunctionality`, returns a keyset of:
  //   AutogradCPU
  //   AutogradCUDA
  //   ...
  //   AutogradPrivateUse3
  m.def("_functionality_to_backend_keys", [](c10::DispatchKey key) {
    std::vector<c10::DispatchKey> keys;
    if (c10::isPerBackendFunctionalityKey(key)) {
      auto ks = c10::DispatchKeySet(key) |
          c10::DispatchKeySet(c10::DispatchKeySet::RAW, c10::full_backend_mask);
      for (auto k : ks) {
        keys.push_back(k);
      }
    } else {
      keys.push_back(key);
    }
    return keys;
  });
  m.def("_dispatch_num_backends", []() { return c10::num_backends; });

#define DEF_ONE(n) .value(#n, c10::DispatchKey::n)

  py::enum_<c10::DispatchKey>(m, "DispatchKey")
      // clang-format off
      DEF_ONE(Undefined)
      DEF_ONE(CompositeExplicitAutogradNonFunctional)
      DEF_ONE(CompositeExplicitAutograd)
      DEF_ONE(CompositeImplicitAutogradNestedTensor)
      DEF_ONE(CompositeImplicitAutograd)
      // NestedTensor is not a backend key
      DEF_ONE(AutogradNestedTensor)
      DEF_ONE(AutogradOther)
      DEF_ONE(Autograd)
      DEF_ONE(Conjugate)
      DEF_ONE(ZeroTensor)
      DEF_ONE(Negative)
      DEF_ONE(BackendSelect)
      DEF_ONE(ADInplaceOrView)
      DEF_ONE(PythonTLSSnapshot)
      DEF_ONE(Python)
      DEF_ONE(FuncTorchDynamicLayerFrontMode)
      DEF_ONE(FuncTorchDynamicLayerBackMode)
      DEF_ONE(FuncTorchBatchedDecomposition)
      DEF_ONE(FuncTorchBatched)
      DEF_ONE(FuncTorchVmapMode)
      DEF_ONE(FuncTorchGradWrapper)
      DEF_ONE(PythonDispatcher)
      DEF_ONE(PreDispatch)
      DEF_ONE(Functionalize)
      DEF_ONE(AutocastCPU)
      DEF_ONE(AutocastMPS)
      DEF_ONE(AutocastXPU)
      DEF_ONE(AutocastHPU)
      DEF_ONE(AutocastIPU)
      DEF_ONE(AutocastCUDA)
      DEF_ONE(AutocastPrivateUse1)
  // clang-format on

#define DEF_SINGLE(n, prefix) .value(#prefix #n, c10::DispatchKey::prefix##n)
#define DEF_MULTIPLE(fullname, prefix)              \
  DEF_SINGLE(, fullname)                            \
  DEF_SINGLE(, StartOf##fullname##Backends)         \
  C10_FORALL_BACKEND_COMPONENTS(DEF_SINGLE, prefix) \
  DEF_SINGLE(, EndOf##fullname##Backends)

      // clang-format off
      C10_FORALL_FUNCTIONALITY_KEYS(DEF_MULTIPLE)
  // clang-format on

#undef DEF_MULTIPLE
#undef DEF_SINGLE
      ;

  py::class_<c10::DispatchKeySet>(m, "DispatchKeySet")
      .def(py::init<c10::DispatchKey>())
      .def("__or__", &c10::DispatchKeySet::operator|)
      .def("__sub__", &c10::DispatchKeySet::operator-)
      .def("__and__", &c10::DispatchKeySet::operator&)
      .def("raw_repr", &c10::DispatchKeySet::raw_repr)
      .def("highestPriorityTypeId", &c10::DispatchKeySet::highestPriorityTypeId)
      .def(
          "remove",
          [](c10::DispatchKeySet self, c10::DispatchKey k) {
            return self.remove(k);
          })
      .def(
          "add",
          [](c10::DispatchKeySet self, c10::DispatchKey k) {
            return self.add(k);
          })
      .def("has", &c10::DispatchKeySet::has)
      .def("__repr__", [](c10::DispatchKeySet d) { return c10::toString(d); });
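  // Illustrative Python use of this binding (a sketch, not a test; note that
  // add/remove return a new keyset rather than mutating, since the lambdas
  // take the set by value):
  //
  //   ks = torch._C.DispatchKeySet(torch._C.DispatchKey.CPU)
  //   ks = ks.add(torch._C.DispatchKey.AutogradCPU)
  //   assert ks.has(torch._C.DispatchKey.CPU)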

  m.attr("_dispatch_autogradother_backends") =
      py::cast(c10::autogradother_backends);

  m.attr("_additional_keys_to_prop_for_wrapper_tensors") =
      py::cast(at::functorch::kKeysToPropagateToWrapper);

  m.attr("_after_autograd_keyset") = py::cast(c10::after_autograd_keyset);
  m.attr("_after_ADInplaceOrView_keyset") =
      py::cast(c10::after_ADInplaceOrView_keyset);

  m.def("_dispatch_has_backend_fallback", [](c10::DispatchKey t) {
    return c10::Dispatcher::singleton().hasBackendFallbackForDispatchKey(t);
  });

  m.def("_dispatch_keyset_full_after", [](c10::DispatchKey t) {
    return c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, t);
  });

  m.def("_dispatch_keyset_full", []() {
    return c10::DispatchKeySet(c10::DispatchKeySet::FULL);
  });

  m.def("_dispatch_is_alias_key", c10::isAliasDispatchKey);

  m.def("_dispatch_keyset_to_string", [](c10::DispatchKeySet keyset) {
    return c10::toString(keyset);
  });

  m.def("_dispatch_get_backend_keyset_from_autograd", [](c10::DispatchKey k) {
    return c10::getBackendKeySetFromAutograd(k);
  });

  m.def("_dispatch_keys", [](const at::Tensor& tensor) {
    auto* impl = tensor.unsafeGetTensorImpl();
    return impl->key_set();
  });
  m.def("_dispatch_tls_local_include_set", []() {
    return c10::impl::tls_local_dispatch_key_set().included_;
  });
  m.def("_dispatch_tls_local_exclude_set", []() {
    return c10::impl::tls_local_dispatch_key_set().excluded_;
  });
  m.def("_functionalization_reapply_views_tls", []() {
    return at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
  });
  m.def(
      "_dispatch_is_included_in_alias",
      [](c10::DispatchKey a, c10::DispatchKey b) {
        return c10::isIncludedInAlias(a, b);
      });

  // DEPRECATED, please don't use this. Instead use
  // torch._C._ExcludeDispatchKeyGuard
  py_context_manager_DEPRECATED<
      c10::impl::ExcludeDispatchKeyGuard,
      c10::DispatchKeySet>(m, "ExcludeDispatchKeyGuard");

  py_context_manager<
      c10::impl::ForceDispatchKeyGuard,
      c10::DispatchKeySet,
      c10::DispatchKeySet>(m, "_ForceDispatchKeyGuard");
  py_context_manager<c10::impl::ForceDispatchKeyGuard>(
      m, "_PreserveDispatchKeyGuard");
  py_context_manager<c10::impl::IncludeDispatchKeyGuard, c10::DispatchKey>(
      m, "_IncludeDispatchKeyGuard");
  py_context_manager<c10::impl::ExcludeDispatchKeyGuard, c10::DispatchKeySet>(
      m, "_ExcludeDispatchKeyGuard");
  py_context_manager<SetExcludeDispatchKeyGuard, c10::DispatchKey, bool>(
      m, "_SetExcludeDispatchKeyGuard");
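  // Each py_context_manager<Guard, Args...> above exposes the C++ RAII guard
  // as a Python context manager whose constructor takes Args. Rough usage
  // sketch:
  //
  //   ks = torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU)
  //   with torch._C._ExcludeDispatchKeyGuard(ks):
  //       ...  # AutogradCPU sits in the TLS exclude set only in this block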

  py_context_manager_DEPRECATED<at::AutoDispatchBelowAutograd>(
      m, "_AutoDispatchBelowAutograd");
  py_context_manager<at::AutoDispatchBelowADInplaceOrView>(
      m, "_AutoDispatchBelowADInplaceOrView");

  // Prints out the name of every operator that has a kernel registered to the
  // Dispatcher under [dispatch_key]. If no arguments are specified, it'll
  // print out the name of every operator that the Dispatcher knows of. This
  // can be useful to answer questions like "list all operators that do not
  // have a CPU kernel".
  m.def(
      "_dispatch_print_registrations_for_dispatch_key",
      [](const char* dispatch_key = "") {
        auto k = std::string(dispatch_key).empty()
            ? std::nullopt
            : std::make_optional(c10::parseDispatchKey(dispatch_key));
        auto op_names =
            c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k);
        for (auto& op : op_names) {
          std::cout << op << '\n';
        }
      },
      py::arg("dispatch_key") = static_cast<const char*>(""));

  m.def(
      "_parse_dispatch_key",
      [](const char* dispatch_key) -> std::optional<c10::DispatchKey> {
        try {
          return c10::parseDispatchKey(dispatch_key);
        } catch (const c10::Error&) {
          return std::nullopt;
        }
      });

  m.def(
      "_dispatch_get_registrations_for_dispatch_key",
      [](const char* dispatch_key = "") {
        auto k = std::string(dispatch_key).empty()
            ? std::nullopt
            : std::make_optional(c10::parseDispatchKey(dispatch_key));
        auto op_names =
            c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k);
        std::vector<std::string> names;
        names.reserve(op_names.size());
        for (auto& op : op_names) {
          names.emplace_back(
              op.name +
              (op.overload_name.empty() ? "" : "." + op.overload_name));
        }
        return names;
      },
      py::arg("dispatch_key") = static_cast<const char*>(""));
  m.def(
      "_dispatch_set_report_error_callback",
      [](c10::OperatorHandle& handle, py::object callback) {
        auto obj = callback.release().ptr();
        auto callback_obj =
            std::make_unique<c10::SafePyObject>(obj, getPyInterpreter());
        handle.setReportErrorCallback_(std::move(callback_obj));
      });

  m.def(
      "_dispatch_is_main_interpreter", []() { return isMainPyInterpreter(); });
  m.def("_dispatch_pystub", [](const char* name, const char* overload) {
    return c10::Dispatcher::singleton().getPyStub(
        c10::OperatorName(name, overload));
  });

  m.def("_replace_", [](const at::Tensor& a, const at::Tensor& b) {
    return at::functionalization::impl::replace_(a, b);
  });
  m.def("_propagate_xla_data", [](const at::Tensor& a, const at::Tensor& b) {
    at::functionalization::impl::propagate_xla_data(a, b);
  });
  m.def("_commit_update", [](const at::Tensor& a) {
    return at::functionalization::impl::commit_update(a);
  });
  m.def("_unsafe_reset_storage", [](const at::Tensor& a) {
    return at::functionalization::impl::unsafe_reset_storage(a);
  });

  m.def("_dispatch_key_for_device", [](const std::string& device_type) {
    auto device = c10::Device(device_type);
    TORCH_CHECK(
        !device.has_index(),
        "Expected device_type string to not have a device index; got ",
        device_type);
    return c10::toString(
        c10::computeDispatchKey(std::nullopt, std::nullopt, device));
  });
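  // For instance, _dispatch_key_for_device("cpu") returns "CPU" and
  // _dispatch_key_for_device("cuda") returns "CUDA"; an indexed device
  // string such as "cuda:0" trips the TORCH_CHECK above.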

  m.def("_are_functorch_transforms_active", []() {
    auto include_set = c10::impl::tls_local_dispatch_key_set().included_;
    return (
        include_set.has(c10::DispatchKey::FuncTorchDynamicLayerFrontMode) ||
        include_set.has(c10::DispatchKey::FuncTorchDynamicLayerBackMode));
  });

  m.def("_get_nested_int", [](int64_t data, int64_t coeff) {
    return c10::SymInt(c10::SymNode(
        c10::make_intrusive<c10::NestedIntSymNodeImpl>(data, coeff)));
  });

  m.def("_get_constant_bool_symnode", [](int64_t data) {
    return c10::SymNode(
        c10::make_intrusive<c10::ConstantSymNodeImpl<bool>>(data));
  });

  m.def("_non_sym_sizes", [](const at::Tensor& a) {
    return a.sizes(); // NB: NOT sym_size
  });

  m.def("_set_throw_on_mutable_data_ptr", [](const at::Tensor& t) {
    if (!t.unsafeGetTensorImpl()->has_storage()) {
      // If the Tensor doesn't have a storage, then accessing .data_ptr()
      // will already raise an error.
      return;
    }
    // Otherwise, set (on the StorageImpl) that accessing (mutable) data_ptr
    // will throw.
    t.unsafeGetTensorImpl()
        ->storage()
        .unsafeGetStorageImpl()
        ->set_throw_on_mutable_data_ptr();
  });

  // Invariant: you must ONLY call this with FakeTensors.
  m.def("_set_warn_deprecated_on_mutable_data_ptr", [](const at::Tensor& t) {
    if (!t.unsafeGetTensorImpl()->has_storage()) {
      // If the Tensor doesn't have a storage, then accessing .data_ptr()
      // will already raise an error.
      return;
    }
    t.unsafeGetTensorImpl()
        ->storage()
        .unsafeGetStorageImpl()
        ->set_warn_deprecated_on_mutable_data_ptr();
  });

  m.def("_only_lift_cpu_tensors", &torch::utils::only_lift_cpu_tensors);
  m.def("_set_only_lift_cpu_tensors", &torch::utils::set_only_lift_cpu_tensors);

  using c10::impl::TorchDispatchModeKey;
  py::enum_<TorchDispatchModeKey>(m, "_TorchDispatchModeKey")
      .value("FUNCTIONAL", TorchDispatchModeKey::FUNCTIONAL)
      .value("PROXY", TorchDispatchModeKey::PROXY)
      .value("FAKE", TorchDispatchModeKey::FAKE);
}

// TODO: dedupe with the kernel
void python_op_registration_trampoline_impl(
    const c10::OperatorHandle& op,
    c10::DispatchKey key,
    c10::DispatchKeySet keyset,
    torch::jit::Stack* stack,
    bool with_keyset,
    bool with_op) {
  auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
  py::gil_scoped_acquire g;
  auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
  const auto& func = python_registrations_[op.operator_name()][key];
  TORCH_INTERNAL_ASSERT(func != nullptr);
  auto* pyobj = func->ptr(getPyInterpreter());
  TORCH_INTERNAL_ASSERT(pyobj != nullptr);
  auto callable = py::reinterpret_borrow<py::object>(pyobj);
  auto obj = with_op ? with_keyset ? callable(
                                         keyset,
                                         torch::detail::getTorchApiFunction(op),
                                         *args_kwargs.first,
                                         **args_kwargs.second)
                                   : callable(
                                         torch::detail::getTorchApiFunction(op),
                                         *args_kwargs.first,
                                         **args_kwargs.second)
      : with_keyset ? callable(keyset, *args_kwargs.first, **args_kwargs.second)
                    : callable(*args_kwargs.first, **args_kwargs.second);
  if (!obj) {
    throw python_error();
  }
  pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
}

} // namespace torch::impl::dispatch