xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/disable_torch_function.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <c10/core/DispatchKey.h>
3 #include <c10/core/impl/LocalDispatchKeySet.h>
4 #include <torch/csrc/python_headers.h>
5 
6 namespace torch {
7 // Sometimes we don't want infinite recursion for subclasses,
8 // Or a way to achieve the old behaviour.
9 
10 // This is an internal utility, not exposed to users.
11 bool torch_function_enabled();
12 PyObject* disabled_torch_function_impl();
13 PyObject* disabled_torch_dispatch_impl();
14 void set_disabled_torch_function_impl(PyObject* value);
15 void set_disabled_torch_dispatch_impl(PyObject* value);
16 // Set ignore_mode to true if you're trying to collect overloaded arguments;
17 // using mode here will improperly cause you to add ALL objects to the
18 // overloaded list even if they don't actually have __torch_function__
19 bool check_has_torch_function(PyObject* obj, bool ignore_mode = false);
20 
21 struct DisableTorchDispatch {
DisableTorchDispatchDisableTorchDispatch22   DisableTorchDispatch()
23       : guard_(c10::DispatchKeySet(
24             {c10::DispatchKey::Python, c10::DispatchKey::PreDispatch})),
25         guard_tls_snapshot_(c10::DispatchKey::PythonTLSSnapshot) {}
26   c10::impl::ExcludeDispatchKeyGuard guard_;
27   c10::impl::ExcludeDispatchKeyGuard guard_tls_snapshot_;
28 };
29 
30 } // namespace torch
31 
32 PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject* unused);
33 PyObject* THPModule_isAllDisabledTorchFunction(
34     PyObject* self,
35     PyObject* unused);
36 PyObject* THPModule_DisableTorchFunctionType();
37 PyObject* THPModule_DisableTorchFunctionSubclassType();
38 PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* args);
39 PyObject* THPModule_disable_torch_dispatch(PyObject* self, PyObject* args);
40 PyObject* THPModule_has_torch_function(PyObject*, PyObject* arg);
41 PyObject* THPModule_has_torch_function_unary(PyObject*, PyObject* obj);
42 PyObject* THPModule_has_torch_function_variadic(
43     PyObject*,
44     PyObject* const* args,
45     Py_ssize_t nargs);
46