#include <torch/csrc/jit/frontend/function_schema_parser.h>
#include <torch/csrc/utils/python_dispatch.h>

#include <ATen/ATen.h>
#include <ATen/FuncTorchTLS.h>
#include <ATen/FunctionalTensorWrapper.h>
#include <ATen/TensorSubclassLikeUtils.h>
#include <ATen/core/NestedIntSymNodeImpl.h>
#include <ATen/core/PythonOpRegistrationTrampoline.h>
#include <ATen/core/dispatch/Dispatcher.h>

#include <ATen/functorch/BatchedTensorImpl.h>
#include <torch/library.h>

#include <c10/core/SafePyObject.h>
#include <torch/csrc/PyInterpreter.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/utils/tensor_new.h>

#include <c10/util/flat_hash_map.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <torch/csrc/inductor/aoti_eager/kernel_holder.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/python_raii.h>

#include <iostream>
#include <sstream> // for std::stringstream used in _dispatch_get_all_op_names
#include <utility>

namespace py = pybind11;

namespace torch::impl::dispatch {

// NB: I'd like to index this on OperatorHandle, but I can't, as I can't
// guarantee that the main interpreter has finished doing all registrations
// before the other interpreters start banging on it
static ska::flat_hash_map<
    c10::OperatorName,
    ska::flat_hash_map<c10::DispatchKey, std::shared_ptr<c10::SafePyObject>>>
    python_registrations_;

static torch::Library::Kind parseKind(const std::string& k) {
  static std::unordered_map<std::string, torch::Library::Kind> kind_map = {
      {"DEF", torch::Library::DEF},
      {"IMPL", torch::Library::IMPL},
      {"FRAGMENT", torch::Library::FRAGMENT},
  };
  auto it = kind_map.find(k);
  TORCH_CHECK(it != kind_map.end(), "could not parse ", k);
  return it->second;
}
static c10::AliasAnalysisKind parseAliasAnalysisKind(const std::string& k) {
  static std::unordered_map<std::string, c10::AliasAnalysisKind> key_map = {
      {"CONSERVATIVE", c10::AliasAnalysisKind::CONSERVATIVE},
      {"FROM_SCHEMA", c10::AliasAnalysisKind::FROM_SCHEMA},
      {"PURE_FUNCTION", c10::AliasAnalysisKind::PURE_FUNCTION},
      {"", c10::AliasAnalysisKind::FROM_SCHEMA}, // default
  };
  auto it = key_map.find(k);
  TORCH_CHECK(it != key_map.end(), "could not parse ", k);
  return it->second;
}

template <typename Func>
inline torch::CppFunction dispatch_str(const char* key, Func&& raw_f) {
  if (key[0] != '\0') {
    return torch::dispatch(
        c10::parseDispatchKey(key), std::forward<Func>(raw_f));
  } else {
    torch::CppFunction f(std::forward<Func>(raw_f));
    return f;
  }
}

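// Usage sketch (hypothetical): dispatch_str lets the test-only bindings below
// accept an optional dispatch key as a string; an empty string means "no key":
//
//   lib.impl("my_ns::my_op",
//            dispatch_str("CPU", [](const at::Tensor& a) { return a; }));
//
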
struct EnableHermeticPyObject {
  EnableHermeticPyObject()
      : old_(c10::impl::HermeticPyObjectTLS::get_state()),
        old_excluded_python_(
            c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Python)),
        old_python_(
            c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Python)),
        old_python_snapshot_(c10::impl::tls_is_dispatch_key_included(
            at::DispatchKey::PythonTLSSnapshot)) {
    c10::impl::HermeticPyObjectTLS::set_state(true);
    c10::impl::tls_set_dispatch_key_excluded(at::DispatchKey::Python, true);
    c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Python, false);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::PythonTLSSnapshot, false);
  }
  ~EnableHermeticPyObject() {
    c10::impl::HermeticPyObjectTLS::set_state(old_);
    c10::impl::tls_set_dispatch_key_excluded(
        at::DispatchKey::Python, old_excluded_python_);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::Python, old_python_);
    c10::impl::tls_set_dispatch_key_included(
        at::DispatchKey::PythonTLSSnapshot, old_python_snapshot_);
  }
  bool old_;
  bool old_excluded_python_;
  bool old_python_;
  bool old_python_snapshot_;
};

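// RAII usage sketch (hypothetical): while the guard is alive, hermetic mode is
// on and the Python dispatch keys are masked off in TLS:
//
//   {
//     EnableHermeticPyObject guard; // saves and flips the TLS state
//     // ... run code that must not re-enter the Python dispatcher ...
//   } // destructor restores the saved state
//
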
class PythonKernelHolder : public c10::OperatorKernel {
  c10::SafePyObject func_;
  c10::DispatchKey dispatch_key_;
  // If "with_keyset", then we expect a keyset as the first arg.
  bool with_keyset_;
  // If "with_op", then we expect the op as the first arg (or second, after
  // the keyset)
  bool with_op_;

 public:
  PythonKernelHolder(
      py::object func,
      c10::DispatchKey dispatch_key,
      bool with_keyset = false,
      bool with_op = false)
      : func_(func.release().ptr(), getPyInterpreter()),
        dispatch_key_(dispatch_key),
        with_keyset_(with_keyset),
        with_op_(with_op) {}

  void operator()(
      const c10::OperatorHandle& op,
      c10::DispatchKeySet keyset,
      torch::jit::Stack* stack) {
    // Figure out if we can handle it hermetically, or if we have
    // to double dispatch

    // If Torch Dispatch Mode is active, use its PyInterpreter for dispatch
    const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
    if (mode_stack_len > 0) {
      const auto& cur_torch_dispatch_mode_state =
          c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
      cur_torch_dispatch_mode_state->pyinterpreter()
          ->python_op_registration_trampoline(
              op, dispatch_key_, keyset, stack, with_keyset_, with_op_);
      return;
    }

    const auto& schema = op.schema();
    const auto num_arguments = schema.arguments().size();

    // Otherwise, find a PyInterpreter on a Tensor if it has the Python key
    // (which means it's a nontrivial tensor subclass)
    for (const auto& ivalue : torch::jit::last(*stack, num_arguments)) {
      if (ivalue.isTensor()) {
        auto* interpreter =
            ivalue.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
        if (interpreter &&
            ivalue.unsafeToTensorImpl()->key_set().has(
                at::DispatchKey::Python)) {
          (*interpreter)
              ->python_op_registration_trampoline(
                  op, dispatch_key_, keyset, stack, with_keyset_, with_op_);
          return;
        }
      } else if (ivalue.isTensorList() || ivalue.isOptionalTensorList()) {
        // NB: use toListRef as it doesn't induce refcount bumps
        // (toTensorListRef is not a thing)
        for (const auto& nv : ivalue.toListRef()) {
          if (nv.isNone()) {
            continue;
          }
          auto* interpreter =
              nv.unsafeToTensorImpl()->pyobj_slot()->pyobj_interpreter();
          if (interpreter &&
              nv.unsafeToTensorImpl()->key_set().has(at::DispatchKey::Python)) {
            (*interpreter)
                ->python_op_registration_trampoline(
                    op, dispatch_key_, keyset, stack, with_keyset_, with_op_);
            return;
          }
        }
      }
    }

    // Nothing requires the operator to be homed to a specific interpreter, so
    // run it on the current interpreter

    auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
    py::gil_scoped_acquire g;
    // Jan 2024: We're slated to get rid of multipy, so stop forcing hermetic
    // mode unconditionally in all situations when you're using multipy.
    // Eventually just delete this entirely.  (Note that you may break multipy
    // anyway this way with dispatcher-registered functions that require
    // hermetic mode to be off.)
#if defined(USE_DEPLOY)
    EnableHermeticPyObject g2;
#endif
    auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
    auto func =
        py::reinterpret_borrow<py::object>(func_.ptr(getPyInterpreter()));
    auto obj = with_op_ ? with_keyset_
            ? func(
                  keyset,
                  torch::detail::getTorchApiFunction(op),
                  *args_kwargs.first,
                  **args_kwargs.second)
            : func(
                  torch::detail::getTorchApiFunction(op),
                  *args_kwargs.first,
                  **args_kwargs.second)
        : with_keyset_ ? func(keyset, *args_kwargs.first, **args_kwargs.second)
                       : func(*args_kwargs.first, **args_kwargs.second);
    if (!obj) {
      throw python_error();
    }
    pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
  }
};

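// Flow sketch (hypothetical, from Python): a kernel registered via
// torch.library, e.g.
//
//   lib = torch.library.Library("my_ns", "FRAGMENT")
//   lib.impl("my_op", my_python_kernel, "CPU")
//
// is wrapped in a PythonKernelHolder; when the dispatcher later reaches the
// CPU entry for my_ns::my_op, operator() above reacquires the GIL and invokes
// my_python_kernel with the boxed stack converted back to Python objects.
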
static torch::_RegisterOrVerify register_or_verify() {
  if (isMainPyInterpreter()) {
    return torch::_RegisterOrVerify::REGISTER;
  } else {
    return torch::_RegisterOrVerify::VERIFY;
  }
}

static py::object ophandle_call_boxed(
    const c10::OperatorHandle& handle,
    const py::args& args,
    const py::kwargs& kwargs) {
  auto stack = torch::jit::createStackForSchema(
      handle.schema(),
      args,
      kwargs,
      /*self=*/std::nullopt);
  {
    pybind11::gil_scoped_release no_gil_guard;
    handle.callBoxed(stack);
  }
  return torch::jit::createPyObjectForStack(std::move(stack));
}

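// Usage sketch (hypothetical, from Python) of the _dispatch_call_boxed
// binding defined below:
//
//   op = torch._C._dispatch_find_schema_or_throw("aten::add", "Tensor")
//   out = torch._C._dispatch_call_boxed(op, x, y, alpha=1)
//
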
// A small RAII guard that lets you explicitly *remove* a key from the TLS
// exclude set.
class SetExcludeDispatchKeyGuard {
 public:
  SetExcludeDispatchKeyGuard(at::DispatchKey k, bool set_excluded)
      : k(k), old(c10::impl::tls_is_dispatch_key_excluded(k)) {
    c10::impl::tls_set_dispatch_key_excluded(k, set_excluded);
  }
  ~SetExcludeDispatchKeyGuard() {
    c10::impl::tls_set_dispatch_key_excluded(k, old);
  }
  SetExcludeDispatchKeyGuard(const SetExcludeDispatchKeyGuard&) = delete;
  SetExcludeDispatchKeyGuard operator=(const SetExcludeDispatchKeyGuard&) =
      delete;
  SetExcludeDispatchKeyGuard(SetExcludeDispatchKeyGuard&&) = delete;
  SetExcludeDispatchKeyGuard operator=(SetExcludeDispatchKeyGuard&&) = delete;

 private:
  at::DispatchKey k;
  bool old;
};

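// Usage sketch (hypothetical): unlike an exclude guard, this one can also
// *clear* an exclude bit for the duration of a scope:
//
//   {
//     SetExcludeDispatchKeyGuard guard(
//         at::DispatchKey::Python, /*set_excluded=*/false);
//     // DispatchKey::Python is guaranteed not excluded here
//   } // prior exclude state restored
//
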
void initDispatchBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  py::class_<c10::OperatorHandle>(m, "_DispatchOperatorHandle")
      .def("schema", &c10::OperatorHandle::schema)
      .def("debug", &c10::OperatorHandle::debug)
      .def(
          "redispatch_boxed",
          [](const py::object& self,
             c10::DispatchKeySet keyset,
             py::args args,
             const py::kwargs& kwargs) {
            auto& handle = self.cast<c10::OperatorHandle&>();
            auto stack = torch::jit::createStackForSchema(
                handle.schema(),
                std::move(args),
                kwargs,
                /*self=*/std::nullopt);
            {
              pybind11::gil_scoped_release no_gil_guard;
              handle.redispatchBoxed(keyset, &stack);
            }
            return torch::jit::createPyObjectForStack(std::move(stack));
          });

  m.def("_dispatch_call_boxed", &ophandle_call_boxed);

  // TODO: figure out how to do chaining
  py::class_<torch::Library>(m, "_DispatchModule")
      .def(
          "reset",
          [](const py::object& self) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().reset();
            return;
          },
          "")
      // Some of these APIs are only for testing and do not work in a multipy
      // environment
      .def(
          "def_",
          [](py::object self, const char* schema, const char* alias) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(
                torch::schema(schema, parseAliasAnalysisKind(alias)));
            return self;
          },
          "",
          py::arg("schema"),
          py::arg("alias") = "")
      // Simulated "legacy" def where alias analysis kind is not set.
      // Ordinarily this can only be exercised from the RegisterOperators()
      // API, but I am not going to bind that here
      .def(
          "def_legacy",
          [](py::object self, const char* schema) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(torch::jit::parseSchema(schema));
            return self;
          },
          "",
          py::arg("schema"))
      // We can't conveniently turn Python functions into valid functions
      // in the dispatcher.  So instead we provide a bunch of precanned
      // functions for testing purposes.  You're NOT intended to actually
      // call these functions; they're just here so we can actually register
      // something
      //
      // Mangling scheme: args_rets.  One character per.
      //  t = Tensor
      .def(
          "def_name_t_t",
          [](py::object self,
             const char* name,
             const char* dispatch,
             const char* debug) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(
                name, dispatch_str(dispatch, [](const at::Tensor& a) {
                        return a;
                      }).debug(debug));
            return self;
          },
          "",
          py::arg("name"),
          py::arg("dispatch") = "",
          py::arg("debug") = "default_def_name_t_t")
      .def(
          "def_schema_t_t",
          [](py::object self,
             const char* schema,
             const char* dispatch,
             const char* alias,
             const char* debug) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().def(
                torch::schema(schema, parseAliasAnalysisKind(alias)),
                dispatch_str(dispatch, [](const at::Tensor& a) {
                  return a;
                }).debug(debug));
            return self;
          },
          "",
          py::arg("schema"),
          py::arg("dispatch") = "",
          py::arg("alias") = "",
          py::arg("debug") = "default_def_schema_t_t")
      // TODO: maybe consider deduplicating the definitions here, it's getting
      // pretty long
      .def(
          "impl_t_t",
          [](py::object self,
             const char* name,
             const char* dispatch,
             const char* debug) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().impl(
                name, dispatch_str(dispatch, [](const at::Tensor& a) {
                        return a;
                      }).debug(debug));
            return self;
          },
          "",
          py::arg("name"),
          py::arg("dispatch") = "",
          py::arg("debug") = "impl_t_t")
      .def(
          "impl_with_aoti_compile",
          [](const py::object& self,
             const char* ns,
             const char* op_name_with_overload,
             c10::DispatchKey dispatch) {
            HANDLE_TH_ERRORS
            std::string reg_op_name =
                std::string(ns).append("::").append(op_name_with_overload);

            auto& lib = self.cast<torch::Library&>();
            lib.impl(
                reg_op_name.c_str(),
                torch::dispatch(
                    dispatch,
                    CppFunction::makeFromBoxedFunctor(
                        std::make_unique<
                            torch::inductor::AOTIPythonKernelHolder>(
                            dispatch, ns, op_name_with_overload))),
                register_or_verify());
            END_HANDLE_TH_ERRORS_PYBIND
          },
          "",
          py::arg("ns"),
          py::arg("op_name_with_overload"),
          py::arg("dispatch"))
      .def(
          "impl",
          [](const py::object& self,
             const char* name,
             // TODO: empty string no longer works
             c10::DispatchKey dispatch,
             py::object func,
             bool with_keyset) {
            HANDLE_TH_ERRORS
            auto& lib = self.cast<torch::Library&>();
            if (func.is(py::module::import("torch.library")
                            .attr("fallthrough_kernel"))) {
              lib.impl(
                  name,
                  torch::dispatch(dispatch, CppFunction::makeFallthrough()),
                  register_or_verify());
            } else {
              lib.impl(
                  name,
                  torch::dispatch(
                      dispatch,
                      CppFunction::makeFromBoxedFunctor(
                          std::make_unique<PythonKernelHolder>(
                              func, dispatch, with_keyset))),
                  register_or_verify());
              python_registrations_[lib._resolve(name)].insert_or_assign(
                  dispatch,
                  std::make_shared<c10::SafePyObject>(
                      func.release().ptr(), getPyInterpreter()));
            }
            END_HANDLE_TH_ERRORS_PYBIND
          },
          "",
          py::arg("name"),
          py::arg("dispatch"),
          py::arg("func"),
          py::arg("with_keyset") = false)
      .def(
          "define",
          [](const py::object& self,
             const char* schema,
             const char* alias_analysis,
             const std::vector<at::Tag>& tags) {
            auto parsed_schema =
                torch::schema(schema, parseAliasAnalysisKind(alias_analysis));
            self.cast<torch::Library&>().def(
                std::move(parsed_schema), tags, register_or_verify());
            // TODO: this is dumb, had to make a second copy
            return torch::schema(schema, parseAliasAnalysisKind(alias_analysis))
                .name();
          },
          "",
          py::arg("schema"),
          py::arg("alias_analysis") = "",
          py::arg("tags") = std::vector<at::Tag>())
      .def(
          "fallback_fallthrough",
          [](py::object self, const char* dispatch) {
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            self.cast<torch::Library&>().fallback(
                dispatch_str(dispatch, CppFunction::makeFallthrough()));
            return self;
          },
          "",
          py::arg("dispatch") = "")
      .def(
          "fallback",
          [](const py::object& self,
             c10::DispatchKey dispatch,
             const py::object& func,
             bool with_keyset) {
            HANDLE_TH_ERRORS
            auto& lib = self.cast<torch::Library&>();
            TORCH_INTERNAL_ASSERT(isMainPyInterpreter());
            if (func.is(py::module::import("torch.library")
                            .attr("fallthrough_kernel"))) {
              lib.fallback(
                  torch::dispatch(dispatch, CppFunction::makeFallthrough()));
            } else {
              lib.fallback(torch::dispatch(
                  dispatch,
                  CppFunction::makeFromBoxedFunctor(
                      std::make_unique<PythonKernelHolder>(
                          func, dispatch, with_keyset, /*with_op*/ true))));
            }
            END_HANDLE_TH_ERRORS_PYBIND
          },
          "",
          py::arg("dispatch"),
          py::arg("func"),
          py::arg("with_keyset") = false);

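// Usage sketch (hypothetical, from Python): given a _DispatchModule handle
// (see _dispatch_library below),
//
//   lib.define("my_op(Tensor x) -> Tensor")
//   lib.impl("my_op", torch._C.DispatchKey.CPU, my_kernel)
//
// defines the schema and registers my_kernel (via PythonKernelHolder) for the
// CPU dispatch key.
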
  m.def(
      "_dispatch_library",
      [](const char* kind,
         std::string name,
         const char* dispatch,
         const char* file,
         uint32_t linenum) {
        HANDLE_TH_ERRORS
        return std::make_unique<torch::Library>(
            parseKind(kind),
            std::move(name),
            std::string(dispatch).empty()
                ? std::nullopt
                : std::make_optional(c10::parseDispatchKey(dispatch)),
            "/dev/null", // temporary workaround
            linenum);
        END_HANDLE_TH_ERRORS_PYBIND
      },
      "",
      py::arg("kind"),
      py::arg("name"),
      py::arg("dispatch"),
      py::arg("file") = "/dev/null",
      py::arg("linenum") = 0);

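// Usage sketch (hypothetical): torch.library.Library drives this binding to
// obtain its underlying torch::Library, roughly
//
//   lib = torch._C._dispatch_library("FRAGMENT", "my_ns", "")
//
// where the empty dispatch string means no dispatch key is pinned.
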
  m.def(
      "_dispatch_find_schema_or_throw",
      [](const char* name, const char* overload_name) -> c10::OperatorHandle {
        return c10::Dispatcher::singleton().findSchemaOrThrow(
            name, overload_name);
      });

  m.def("_dispatch_dump", [](const char* name) -> std::string {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    if (!op) {
      return "";
    } else {
      return op->dumpState();
    }
  });

  m.def("_dispatch_dump_table", [](const char* name) -> std::string {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    if (!op) {
      return "";
    } else {
      return op->dumpComputedTable();
    }
  });

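// Usage sketch (hypothetical, from Python):
//
//   print(torch._C._dispatch_dump_table("aten::add.Tensor"))
//
// prints the runtime-computed dispatch table for that overload, one entry per
// dispatch key.
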
  m.def("_dispatch_check_invariants", [](const char* name) {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    if (op) {
      op->checkInvariants();
    }
  });

  m.def("_dispatch_check_all_invariants", []() {
    c10::Dispatcher::singleton().checkInvariants();
  });

  m.def("_dispatch_has_kernel", [](const char* name) -> bool {
    auto op = c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
    return static_cast<bool>(op);
  });

  m.def(
      // Returns whether or not a direct kernel registration exists
      // for this <op_name, dispatch_key> pair.
      "_dispatch_has_kernel_for_dispatch_key",
      [](const char* name, c10::DispatchKey dispatch) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasKernelForDispatchKey(dispatch);
      });

  m.def(
      // Returns whether or not the kernel for this dispatch key is a
      // fallthrough kernel
      "_dispatch_kernel_for_dispatch_key_is_fallthrough",
      [](const char* name, c10::DispatchKey dispatch) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->isKernelFallthroughKernel(dispatch);
      });

  m.def(
      "_dispatch_has_kernel_for_any_dispatch_key",
      [](const char* name, c10::DispatchKeySet ks) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasKernelForAnyDispatchKey(ks);
      });

  m.def(
      // Returns whether or not there is an entry in the runtime computed
      // dispatch table for this <op_name, dispatch_key> pair. For example, if
      // "op" has a `CompositeImplicitAutograd` kernel, then
      // _dispatch_has_computed_kernel_for_dispatch_key(op, backend) will
      // return true for all backends that are part of the alias set for
      // CompositeImplicitAutograd.
      "_dispatch_has_computed_kernel_for_dispatch_key",
      [](const char* name, const char* dispatch) -> bool {
        auto op =
            c10::Dispatcher::singleton().findOp(torch::jit::parseName(name));
        TORCH_CHECK(op, "operator ", name, " does not exist");
        return op->hasComputedKernelForDispatchKey(
            c10::parseDispatchKey(dispatch));
      });

  m.def("_dispatch_find_dangling_impls", []() -> std::vector<std::string> {
    auto danglingImpls = c10::Dispatcher::singleton().findDanglingImpls();

    std::vector<std::string> states;
    states.reserve(danglingImpls.size());
    for (auto& danglingImpl : danglingImpls) {
      states.emplace_back(danglingImpl.dumpState());
    }

    return states;
  });

  m.def("_dispatch_get_all_op_names", []() -> std::vector<std::string> {
    auto op_names = c10::Dispatcher::singleton().getAllOpNames();

    std::vector<std::string> names;
    names.reserve(op_names.size());
    for (auto& op : op_names) {
      std::stringstream ss;
      ss << op.name;
      if (!op.overload_name.empty()) {
        ss << "." << op.overload_name;
      }
      names.emplace_back(ss.str());
    }

    return names;
  });

  m.def(
      "_dispatch_tls_set_dispatch_key_excluded",
      [](c10::DispatchKey dispatch_key, bool desired_state) {
        c10::impl::tls_set_dispatch_key_excluded(dispatch_key, desired_state);
      });
  m.def(
      "_dispatch_tls_is_dispatch_key_excluded",
      [](c10::DispatchKey dispatch_key) {
        return c10::impl::tls_is_dispatch_key_excluded(dispatch_key);
      });
  m.def(
      "_dispatch_tls_set_dispatch_key_included",
      [](c10::DispatchKey dispatch_key, bool desired_state) {
        c10::impl::tls_set_dispatch_key_included(dispatch_key, desired_state);
      });
  m.def(
      "_dispatch_tls_is_dispatch_key_included",
      [](c10::DispatchKey dispatch_key) {
        return c10::impl::tls_is_dispatch_key_included(dispatch_key);
      });

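// Usage sketch (hypothetical, from Python): the TLS accessors above pair up:
//
//   torch._C._dispatch_tls_set_dispatch_key_excluded(
//       torch._C.DispatchKey.AutogradCPU, True)
//   assert torch._C._dispatch_tls_is_dispatch_key_excluded(
//       torch._C.DispatchKey.AutogradCPU)
//
// Prefer the RAII context managers bound further below over raw set/unset
// pairs.
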
  m.def("_dispatch_isTensorSubclassLike", [](const at::Tensor& tensor) {
    return at::isTensorSubclassLike(tensor);
  });

  m.def("_dispatch_key_name", [](c10::DispatchKey k) {
    return c10::toString(k);
  });
  m.def("_dispatch_key_parse", [](c10::DispatchKey k) { return k; });
  m.def("_to_functionality_key", [](c10::DispatchKey k) {
    return c10::toFunctionalityKey(k);
  });
  // E.g. given `DispatchKey::AutogradFunctionality`, returns a keyset of:
  //  AutogradCPU
  //  AutogradCUDA
  //  ...
  //  AutogradPrivateUse3
  m.def("_functionality_to_backend_keys", [](c10::DispatchKey key) {
    std::vector<c10::DispatchKey> keys;
    if (c10::isPerBackendFunctionalityKey(key)) {
      auto ks = c10::DispatchKeySet(key) |
          c10::DispatchKeySet(c10::DispatchKeySet::RAW, c10::full_backend_mask);
      for (auto k : ks) {
        keys.push_back(k);
      }
    } else {
      keys.push_back(key);
    }
    return keys;
  });
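
// Usage sketch (hypothetical, from Python) for the binding above:
//
//   ks = torch._C._functionality_to_backend_keys(torch._C.DispatchKey.Sparse)
//   # roughly [DispatchKey.SparseCPU, DispatchKey.SparseCUDA, ...]
//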
  m.def("_dispatch_num_backends", []() { return c10::num_backends; });

#define DEF_ONE(n) .value(#n, c10::DispatchKey::n)

  py::enum_<c10::DispatchKey>(m, "DispatchKey")
      // clang-format off
      DEF_ONE(Undefined)
      DEF_ONE(CompositeExplicitAutogradNonFunctional)
      DEF_ONE(CompositeExplicitAutograd)
      DEF_ONE(CompositeImplicitAutogradNestedTensor)
      DEF_ONE(CompositeImplicitAutograd)
      // NestedTensor is not a backend key
      DEF_ONE(AutogradNestedTensor)
      DEF_ONE(AutogradOther)
      DEF_ONE(Autograd)
      DEF_ONE(Conjugate)
      DEF_ONE(ZeroTensor)
      DEF_ONE(Negative)
      DEF_ONE(BackendSelect)
      DEF_ONE(ADInplaceOrView)
      DEF_ONE(PythonTLSSnapshot)
      DEF_ONE(Python)
      DEF_ONE(FuncTorchDynamicLayerFrontMode)
      DEF_ONE(FuncTorchDynamicLayerBackMode)
      DEF_ONE(FuncTorchBatchedDecomposition)
      DEF_ONE(FuncTorchBatched)
      DEF_ONE(FuncTorchVmapMode)
      DEF_ONE(FuncTorchGradWrapper)
      DEF_ONE(PythonDispatcher)
      DEF_ONE(PreDispatch)
      DEF_ONE(Functionalize)
      DEF_ONE(AutocastCPU)
      DEF_ONE(AutocastMPS)
      DEF_ONE(AutocastXPU)
      DEF_ONE(AutocastHPU)
      DEF_ONE(AutocastIPU)
      DEF_ONE(AutocastCUDA)
      DEF_ONE(AutocastPrivateUse1)
  // clang-format on

#define DEF_SINGLE(n, prefix) .value(#prefix #n, c10::DispatchKey::prefix##n)
#define DEF_MULTIPLE(fullname, prefix)              \
  DEF_SINGLE(, fullname)                            \
  DEF_SINGLE(, StartOf##fullname##Backends)         \
  C10_FORALL_BACKEND_COMPONENTS(DEF_SINGLE, prefix) \
  DEF_SINGLE(, EndOf##fullname##Backends)

      // clang-format off
  C10_FORALL_FUNCTIONALITY_KEYS(DEF_MULTIPLE)
  // clang-format on

#undef DEF_MULTIPLE
#undef DEF_SINGLE
          ;

  py::class_<c10::DispatchKeySet>(m, "DispatchKeySet")
      .def(py::init<c10::DispatchKey>())
      .def("__or__", &c10::DispatchKeySet::operator|)
      .def("__sub__", &c10::DispatchKeySet::operator-)
      .def("__and__", &c10::DispatchKeySet::operator&)
      .def("raw_repr", &c10::DispatchKeySet::raw_repr)
      .def("highestPriorityTypeId", &c10::DispatchKeySet::highestPriorityTypeId)
      .def(
          "remove",
          [](c10::DispatchKeySet self, c10::DispatchKey k) {
            return self.remove(k);
          })
      .def(
          "add",
          [](c10::DispatchKeySet self, c10::DispatchKey k) {
            return self.add(k);
          })
      .def("has", &c10::DispatchKeySet::has)
      .def("__repr__", [](c10::DispatchKeySet d) { return c10::toString(d); });

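// Usage sketch (hypothetical, from Python):
//
//   ks = torch._C.DispatchKeySet(torch._C.DispatchKey.CPU)
//   ks = ks.add(torch._C.DispatchKey.AutogradCPU)
//   assert ks.has(torch._C.DispatchKey.CPU)
//   print(ks)  # __repr__ goes through c10::toString
//
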
  m.attr("_dispatch_autogradother_backends") =
      py::cast(c10::autogradother_backends);

  m.attr("_additional_keys_to_prop_for_wrapper_tensors") =
      py::cast(at::functorch::kKeysToPropagateToWrapper);

  m.attr("_after_autograd_keyset") = py::cast(c10::after_autograd_keyset);
  m.attr("_after_ADInplaceOrView_keyset") =
      py::cast(c10::after_ADInplaceOrView_keyset);

  m.def("_dispatch_has_backend_fallback", [](c10::DispatchKey t) {
    return c10::Dispatcher::singleton().hasBackendFallbackForDispatchKey(t);
  });

  m.def("_dispatch_keyset_full_after", [](c10::DispatchKey t) {
    return c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, t);
  });

  m.def("_dispatch_keyset_full", []() {
    return c10::DispatchKeySet(c10::DispatchKeySet::FULL);
  });

  m.def("_dispatch_is_alias_key", c10::isAliasDispatchKey);

  m.def("_dispatch_keyset_to_string", [](c10::DispatchKeySet keyset) {
    return c10::toString(keyset);
  });

  m.def("_dispatch_get_backend_keyset_from_autograd", [](c10::DispatchKey k) {
    return c10::getBackendKeySetFromAutograd(k);
  });

  m.def("_dispatch_keys", [](const at::Tensor& tensor) {
    auto* impl = tensor.unsafeGetTensorImpl();
    return impl->key_set();
  });
  m.def("_dispatch_tls_local_include_set", []() {
    return c10::impl::tls_local_dispatch_key_set().included_;
  });
  m.def("_dispatch_tls_local_exclude_set", []() {
    return c10::impl::tls_local_dispatch_key_set().excluded_;
  });
  m.def("_functionalization_reapply_views_tls", []() {
    return at::functionalization::impl::getFunctionalizationReapplyViewsTLS();
  });
  m.def(
      "_dispatch_is_included_in_alias",
      [](c10::DispatchKey a, c10::DispatchKey b) {
        return c10::isIncludedInAlias(a, b);
      });

  // DEPRECATED, please don't use this. Instead use
  // torch._C._ExcludeDispatchKeyGuard
  py_context_manager_DEPRECATED<
      c10::impl::ExcludeDispatchKeyGuard,
      c10::DispatchKeySet>(m, "ExcludeDispatchKeyGuard");

  py_context_manager<
      c10::impl::ForceDispatchKeyGuard,
      c10::DispatchKeySet,
      c10::DispatchKeySet>(m, "_ForceDispatchKeyGuard");
  py_context_manager<c10::impl::ForceDispatchKeyGuard>(
      m, "_PreserveDispatchKeyGuard");
  py_context_manager<c10::impl::IncludeDispatchKeyGuard, c10::DispatchKey>(
      m, "_IncludeDispatchKeyGuard");
  py_context_manager<c10::impl::ExcludeDispatchKeyGuard, c10::DispatchKeySet>(
      m, "_ExcludeDispatchKeyGuard");
  py_context_manager<SetExcludeDispatchKeyGuard, c10::DispatchKey, bool>(
      m, "_SetExcludeDispatchKeyGuard");

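// Usage sketch (hypothetical, from Python): these guards are exposed as
// context managers, e.g.
//
//   with torch._C._ExcludeDispatchKeyGuard(
//           torch._C.DispatchKeySet(torch._C.DispatchKey.AutogradCPU)):
//       ...  # AutogradCPU is excluded inside this block
//
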
  py_context_manager_DEPRECATED<at::AutoDispatchBelowAutograd>(
      m, "_AutoDispatchBelowAutograd");
  py_context_manager<at::AutoDispatchBelowADInplaceOrView>(
      m, "_AutoDispatchBelowADInplaceOrView");

  // Prints out the name of every operator that has a kernel registered to the
  // Dispatcher under [dispatch_key]. If no arguments are specified, it'll
  // print out the name of every operator that the Dispatcher knows of. This
  // can be useful to answer questions like "list all operators that do not
  // have a CPU kernel".
  m.def(
      "_dispatch_print_registrations_for_dispatch_key",
      [](const char* dispatch_key = "") {
        auto k = std::string(dispatch_key).empty()
            ? std::nullopt
            : std::make_optional(c10::parseDispatchKey(dispatch_key));
        auto op_names =
            c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k);
        for (auto& op : op_names) {
          std::cout << op << '\n';
        }
      },
      py::arg("dispatch_key") = static_cast<const char*>(""));

  m.def(
      "_parse_dispatch_key",
      [](const char* dispatch_key) -> std::optional<c10::DispatchKey> {
        try {
          return c10::parseDispatchKey(dispatch_key);
        } catch (const c10::Error& err) {
          return std::nullopt;
        }
      });

  m.def(
      "_dispatch_get_registrations_for_dispatch_key",
      [](const char* dispatch_key = "") {
        auto k = std::string(dispatch_key).empty()
            ? std::nullopt
            : std::make_optional(c10::parseDispatchKey(dispatch_key));
        auto op_names =
            c10::Dispatcher::singleton().getRegistrationsForDispatchKey(k);
        std::vector<std::string> names;
        names.reserve(op_names.size());
        for (auto& op : op_names) {
          names.emplace_back(
              op.name +
              (op.overload_name.empty() ? "" : "." + op.overload_name));
        }
        return names;
      },
      py::arg("dispatch_key") = static_cast<const char*>(""));
  m.def(
      "_dispatch_set_report_error_callback",
      [](c10::OperatorHandle& handle, py::object callback) {
        auto obj = callback.release().ptr();
        auto callback_obj =
            std::make_unique<c10::SafePyObject>(obj, getPyInterpreter());
        handle.setReportErrorCallback_(std::move(callback_obj));
      });

  m.def(
      "_dispatch_is_main_interpreter", []() { return isMainPyInterpreter(); });
  m.def("_dispatch_pystub", [](const char* name, const char* overload) {
    return c10::Dispatcher::singleton().getPyStub(
        c10::OperatorName(name, overload));
  });

  m.def("_replace_", [](const at::Tensor& a, const at::Tensor& b) {
    return at::functionalization::impl::replace_(a, b);
  });
  m.def("_propagate_xla_data", [](const at::Tensor& a, const at::Tensor& b) {
    at::functionalization::impl::propagate_xla_data(a, b);
  });
  m.def("_commit_update", [](const at::Tensor& a) {
    return at::functionalization::impl::commit_update(a);
  });
  m.def("_unsafe_reset_storage", [](const at::Tensor& a) {
    return at::functionalization::impl::unsafe_reset_storage(a);
  });

  m.def("_dispatch_key_for_device", [](const std::string& device_type) {
    auto device = c10::Device(device_type);
    TORCH_CHECK(
        !device.has_index(),
        "Expected device_type string to not have a device index; got ",
        device_type);
    return c10::toString(
        c10::computeDispatchKey(std::nullopt, std::nullopt, device));
  });

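// Usage sketch (hypothetical, from Python):
//
//   torch._C._dispatch_key_for_device("cuda")   # -> "CUDA"
//   torch._C._dispatch_key_for_device("cuda:0") # raises: index not allowed
//
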
  m.def("_are_functorch_transforms_active", []() {
    auto include_set = c10::impl::tls_local_dispatch_key_set().included_;
    return (
        include_set.has(c10::DispatchKey::FuncTorchDynamicLayerFrontMode) ||
        include_set.has(c10::DispatchKey::FuncTorchDynamicLayerBackMode));
  });

  m.def("_get_nested_int", [](int64_t data, int64_t coeff) {
    return c10::SymInt(c10::SymNode(
        c10::make_intrusive<c10::NestedIntSymNodeImpl>(data, coeff)));
  });

  m.def("_get_constant_bool_symnode", [](int64_t data) {
    return c10::SymNode(
        c10::make_intrusive<c10::ConstantSymNodeImpl<bool>>(data));
  });

  m.def("_non_sym_sizes", [](const at::Tensor& a) {
    return a.sizes(); // NB: NOT sym_size
  });

  m.def("_set_throw_on_mutable_data_ptr", [](const at::Tensor& t) {
    if (!t.unsafeGetTensorImpl()->has_storage()) {
      // If the Tensor doesn't have a storage, then accessing .data_ptr()
      // will already raise an error.
      return;
    }
    // Otherwise, set (on the StorageImpl) that accessing (mutable) data_ptr
    // will throw.
    t.unsafeGetTensorImpl()
        ->storage()
        .unsafeGetStorageImpl()
        ->set_throw_on_mutable_data_ptr();
  });

  // Invariant: you must ONLY call this with FakeTensors.
  m.def("_set_warn_deprecated_on_mutable_data_ptr", [](const at::Tensor& t) {
    if (!t.unsafeGetTensorImpl()->has_storage()) {
      // If the Tensor doesn't have a storage, then accessing .data_ptr()
      // will already raise an error.
      return;
    }
    t.unsafeGetTensorImpl()
        ->storage()
        .unsafeGetStorageImpl()
        ->set_warn_deprecated_on_mutable_data_ptr();
  });

  m.def("_only_lift_cpu_tensors", &torch::utils::only_lift_cpu_tensors);
  m.def("_set_only_lift_cpu_tensors", &torch::utils::set_only_lift_cpu_tensors);

  using c10::impl::TorchDispatchModeKey;
  py::enum_<TorchDispatchModeKey>(m, "_TorchDispatchModeKey")
      .value("FUNCTIONAL", TorchDispatchModeKey::FUNCTIONAL)
      .value("PROXY", TorchDispatchModeKey::PROXY)
      .value("FAKE", TorchDispatchModeKey::FAKE);
}

// TODO: dedupe with the kernel
void python_op_registration_trampoline_impl(
    const c10::OperatorHandle& op,
    c10::DispatchKey key,
    c10::DispatchKeySet keyset,
    torch::jit::Stack* stack,
    bool with_keyset,
    bool with_op) {
  auto arguments = torch::jit::pop(*stack, op.schema().arguments().size());
  py::gil_scoped_acquire g;
  auto args_kwargs = parseIValuesToPyArgsKwargs(op, arguments);
  const auto& func = python_registrations_[op.operator_name()][key];
  TORCH_INTERNAL_ASSERT(func != nullptr);
  auto* pyobj = func->ptr(getPyInterpreter());
  TORCH_INTERNAL_ASSERT(pyobj != nullptr);
  auto callable = py::reinterpret_borrow<py::object>(pyobj);
  auto obj = with_op ? with_keyset ? callable(
                                         keyset,
                                         torch::detail::getTorchApiFunction(op),
                                         *args_kwargs.first,
                                         **args_kwargs.second)
                                   : callable(
                                         torch::detail::getTorchApiFunction(op),
                                         *args_kwargs.first,
                                         **args_kwargs.second)
      : with_keyset ? callable(keyset, *args_kwargs.first, **args_kwargs.second)
                    : callable(*args_kwargs.first, **args_kwargs.second);
  if (!obj) {
    throw python_error();
  }
  pushPyOutToStack(op, stack, obj, "PythonKernelHolder");
}

} // namespace torch::impl::dispatch