xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/disable_torch_function.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/Exceptions.h>
2 #include <torch/csrc/autograd/python_variable.h>
3 #include <torch/csrc/utils/disable_torch_function.h>
4 #include <torch/csrc/utils/pybind.h>
5 #include <torch/csrc/utils/python_strings.h>
6 
7 #include <ATen/PythonTorchFunctionTLS.h>
8 
9 namespace torch {
10 PyObject* disabled_torch_function = nullptr;
11 PyObject* disabled_torch_dispatch = nullptr;
12 
torch_function_enabled()13 bool torch_function_enabled() {
14   return at::impl::PythonTorchFunctionTLS::get_disabled_state() ==
15       at::impl::TorchFunctionDisabledState::ENABLED;
16 }
17 
disabled_torch_function_impl()18 PyObject* disabled_torch_function_impl() {
19   return disabled_torch_function;
20 }
21 
set_disabled_torch_function_impl(PyObject * value)22 void set_disabled_torch_function_impl(PyObject* value) {
23   disabled_torch_function = value;
24 }
25 
disabled_torch_dispatch_impl()26 PyObject* disabled_torch_dispatch_impl() {
27   return disabled_torch_dispatch;
28 }
29 
set_disabled_torch_dispatch_impl(PyObject * value)30 void set_disabled_torch_dispatch_impl(PyObject* value) {
31   disabled_torch_dispatch = value;
32 }
33 } // namespace torch
34 
35 typedef struct {
36   PyObject_HEAD
37       /* Type-specific fields go here. */
38       at::impl::TorchFunctionDisabledState old_state;
39 } DisableTorchFunctionSubclass;
40 
DisableTorchFunctionSubclass__enter(PyObject * self,PyObject * unused)41 PyObject* DisableTorchFunctionSubclass__enter(
42     PyObject* self,
43     PyObject* unused) {
44   const auto old_state = at::impl::PythonTorchFunctionTLS::get_disabled_state();
45   ((DisableTorchFunctionSubclass*)self)->old_state = old_state;
46   if (old_state == at::impl::TorchFunctionDisabledState::ENABLED) {
47     at::impl::PythonTorchFunctionTLS::set_disabled_state(
48         at::impl::TorchFunctionDisabledState::SUBCLASSES_DISABLED);
49   }
50   Py_RETURN_NONE;
51 }
52 
DisableTorchFunctionSubclass__exit(PyObject * self,PyObject * unused)53 PyObject* DisableTorchFunctionSubclass__exit(PyObject* self, PyObject* unused) {
54   at::impl::PythonTorchFunctionTLS::set_disabled_state(
55       ((DisableTorchFunctionSubclass*)self)->old_state);
56   Py_RETURN_NONE;
57 }
58 
THPModule_isEnabledTorchFunction(PyObject * self,PyObject * unused)59 PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject* unused) {
60   if (torch::torch_function_enabled()) {
61     Py_RETURN_TRUE;
62   } else {
63     Py_RETURN_FALSE;
64   }
65 }
66 
THPModule_isAllDisabledTorchFunction(PyObject * self,PyObject * unused)67 PyObject* THPModule_isAllDisabledTorchFunction(
68     PyObject* self,
69     PyObject* unused) {
70   if (at::impl::torch_function_all_disabled()) {
71     Py_RETURN_TRUE;
72   } else {
73     Py_RETURN_FALSE;
74   }
75 }
76 
77 static PyMethodDef DisableTorchFunctionSubclass_methods[] = { // NOLINT
78     {"__enter__", DisableTorchFunctionSubclass__enter, METH_NOARGS, nullptr},
79     {"__exit__", DisableTorchFunctionSubclass__exit, METH_VARARGS, nullptr},
80     {nullptr, nullptr, 0, nullptr}};
81 
82 PyTypeObject DisableTorchFunctionSubclassType = {
83     PyVarObject_HEAD_INIT(
84         nullptr,
85         0) "torch._C.DisableTorchFunctionSubclass", /* tp_name */
86     sizeof(DisableTorchFunctionSubclass), /* tp_basicsize */
87     0, /* tp_itemsize */
88     nullptr, /* tp_dealloc */
89     0, /* tp_vectorcall_offset */
90     nullptr, /* tp_getattr */
91     nullptr, /* tp_setattr */
92     nullptr, /* tp_reserved */
93     nullptr, /* tp_repr */
94     nullptr, /* tp_as_number */
95     nullptr, /* tp_as_sequence */
96     nullptr, /* tp_as_mapping */
97     nullptr, /* tp_hash  */
98     nullptr, /* tp_call */
99     nullptr, /* tp_str */
100     nullptr, /* tp_getattro */
101     nullptr, /* tp_setattro */
102     nullptr, /* tp_as_buffer */
103     Py_TPFLAGS_DEFAULT, /* tp_flags */
104     nullptr, /* tp_doc */
105     nullptr, /* tp_traverse */
106     nullptr, /* tp_clear */
107     nullptr, /* tp_richcompare */
108     0, /* tp_weaklistoffset */
109     nullptr, /* tp_iter */
110     nullptr, /* tp_iternext */
111     DisableTorchFunctionSubclass_methods, /* tp_methods */
112     nullptr, /* tp_members */
113     nullptr, /* tp_getset */
114     nullptr, /* tp_base */
115     nullptr, /* tp_dict */
116     nullptr, /* tp_descr_get */
117     nullptr, /* tp_descr_set */
118     0, /* tp_dictoffset */
119     nullptr, /* tp_init */
120     PyType_GenericAlloc, /* tp_alloc */
121     PyType_GenericNew, /* tp_new */
122 };
123 
THPModule_DisableTorchFunctionSubclassType()124 PyObject* THPModule_DisableTorchFunctionSubclassType() {
125   if (PyType_Ready(&DisableTorchFunctionSubclassType) < 0) {
126     return nullptr;
127   }
128 
129   return (PyObject*)(&DisableTorchFunctionSubclassType);
130 }
131 
132 typedef struct {
133   PyObject_HEAD
134       /* Type-specific fields go here. */
135       at::impl::TorchFunctionDisabledState old_state;
136 } DisableTorchFunction;
137 
DisableTorchFunction__enter(PyObject * self,PyObject * unused)138 PyObject* DisableTorchFunction__enter(PyObject* self, PyObject* unused) {
139   ((DisableTorchFunctionSubclass*)self)->old_state =
140       at::impl::PythonTorchFunctionTLS::get_disabled_state();
141   at::impl::PythonTorchFunctionTLS::set_disabled_state(
142       at::impl::TorchFunctionDisabledState::ALL_DISABLED);
143   Py_RETURN_NONE;
144 }
145 
DisableTorchFunction__exit(PyObject * self,PyObject * unused)146 PyObject* DisableTorchFunction__exit(PyObject* self, PyObject* unused) {
147   at::impl::PythonTorchFunctionTLS::set_disabled_state(
148       ((DisableTorchFunctionSubclass*)self)->old_state);
149   Py_RETURN_NONE;
150 }
151 
152 static PyMethodDef DisableTorchFunction_methods[] = { // NOLINT
153     {"__enter__", DisableTorchFunction__enter, METH_NOARGS, nullptr},
154     {"__exit__", DisableTorchFunction__exit, METH_VARARGS, nullptr},
155     {nullptr, nullptr, 0, nullptr}};
156 
157 PyTypeObject DisableTorchFunctionType = {
158     PyVarObject_HEAD_INIT(
159         nullptr,
160         0) "torch._C.DisableTorchFunction", /* tp_name */
161     sizeof(DisableTorchFunction), /* tp_basicsize */
162     0, /* tp_itemsize */
163     nullptr, /* tp_dealloc */
164     0, /* tp_vectorcall_offset */
165     nullptr, /* tp_getattr */
166     nullptr, /* tp_setattr */
167     nullptr, /* tp_reserved */
168     nullptr, /* tp_repr */
169     nullptr, /* tp_as_number */
170     nullptr, /* tp_as_sequence */
171     nullptr, /* tp_as_mapping */
172     nullptr, /* tp_hash  */
173     nullptr, /* tp_call */
174     nullptr, /* tp_str */
175     nullptr, /* tp_getattro */
176     nullptr, /* tp_setattro */
177     nullptr, /* tp_as_buffer */
178     Py_TPFLAGS_DEFAULT, /* tp_flags */
179     nullptr, /* tp_doc */
180     nullptr, /* tp_traverse */
181     nullptr, /* tp_clear */
182     nullptr, /* tp_richcompare */
183     0, /* tp_weaklistoffset */
184     nullptr, /* tp_iter */
185     nullptr, /* tp_iternext */
186     DisableTorchFunction_methods, /* tp_methods */
187     nullptr, /* tp_members */
188     nullptr, /* tp_getset */
189     nullptr, /* tp_base */
190     nullptr, /* tp_dict */
191     nullptr, /* tp_descr_get */
192     nullptr, /* tp_descr_set */
193     0, /* tp_dictoffset */
194     nullptr, /* tp_init */
195     PyType_GenericAlloc, /* tp_alloc */
196     PyType_GenericNew, /* tp_new */
197 };
198 
THPModule_DisableTorchFunctionType()199 PyObject* THPModule_DisableTorchFunctionType() {
200   if (PyType_Ready(&DisableTorchFunctionType) < 0) {
201     return nullptr;
202   }
203 
204   return (PyObject*)(&DisableTorchFunctionType);
205 }
206 
THPModule_disable_torch_function(PyObject * self,PyObject * a)207 PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* a) {
208   HANDLE_TH_ERRORS
209   PyObject *func = nullptr, *types = nullptr, *args = nullptr,
210            *kwargs = nullptr;
211   if (!PyArg_ParseTuple(a, "OO|OO", &func, &types, &args, &kwargs)) {
212     return nullptr;
213   }
214   py::tuple py_args;
215   if (args == nullptr) {
216     py_args = py::make_tuple();
217   } else if (PyList_Check(args)) {
218     py_args = py::reinterpret_steal<py::tuple>(PyList_AsTuple(args));
219   } else if (PyTuple_Check(args)) {
220     py_args = py::reinterpret_borrow<py::tuple>(args);
221   } else {
222     throw torch::TypeError(
223         "expected List or Tuple (got %s)", Py_TYPE(args)->tp_name);
224   }
225 
226   // These are all C-API calls so no exceptions will be raised
227   // and therefore no need for RAII approach to storing
228   // the old value.
229   auto old_value = at::impl::PythonTorchFunctionTLS::get_disabled_state();
230   if (old_value == at::impl::TorchFunctionDisabledState::ENABLED) {
231     at::impl::PythonTorchFunctionTLS::set_disabled_state(
232         at::impl::TorchFunctionDisabledState::SUBCLASSES_DISABLED);
233   }
234   // kwargs can safely be nullptr here.
235   PyObject* result = PyObject_Call(func, py_args.ptr(), kwargs);
236   at::impl::PythonTorchFunctionTLS::set_disabled_state(old_value);
237   return result;
238   END_HANDLE_TH_ERRORS
239 }
240 
THPModule_disable_torch_dispatch(PyObject * self,PyObject * a)241 PyObject* THPModule_disable_torch_dispatch(PyObject* self, PyObject* a) {
242   HANDLE_TH_ERRORS
243   PyObject *func = nullptr, *types = nullptr, *args = nullptr,
244            *kwargs = nullptr;
245   if (!PyArg_ParseTuple(a, "OO|OO", &func, &types, &args, &kwargs)) {
246     return nullptr;
247   }
248   py::tuple py_args;
249   if (args == nullptr) {
250     py_args = py::make_tuple();
251   } else if (PyList_Check(args)) {
252     py_args = py::reinterpret_steal<py::tuple>(PyList_AsTuple(args));
253   } else if (PyTuple_Check(args)) {
254     py_args = py::reinterpret_borrow<py::tuple>(args);
255   } else {
256     throw torch::TypeError(
257         "expected List or Tuple (got %s)", Py_TYPE(args)->tp_name);
258   }
259 
260   // This implementation is not completely correct.  The moral
261   // meaning of this function is that we should do a redispatch
262   // "after" PythonKey, aka a redispatch() call.  But we don't have a
263   // dispatcher call here; we have an opaque Python object.
264   //
265   // What we have here is a close approximation: instead of redispatch(), we
266   // just exclude Python and all the keys before it, so that we will go
267   // to the next key after Python.  The difference, however, is we are
268   // now PERMANENTLY after Python.  We don't think there are any legitimate
269   // cases where we want to go for another round on the entire dispatcher key
270   // set, but if there are, then we will have to do something else here.
271   c10::impl::ExcludeDispatchKeyGuard guard_(
272       // TODO: add constructor for this specifically
273       c10::DispatchKeySet(c10::DispatchKeySet::FULL) -
274       c10::DispatchKeySet(
275           c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Python)
276       // NB: off by one hazard here, but it works out: python key is not
277       // included in AFTER, so it is included in the negation (and that's
278       // correct: we want to exclude Python key and everything BEFORE it.)
279   );
280   auto r = PyObject_Call(func, py_args.ptr(), kwargs);
281   if (r == nullptr)
282     throw python_error();
283   return r;
284   END_HANDLE_TH_ERRORS
285 }
286 
287 // Makes sure that we don't check for __torch_function__ on basic Python types
is_basic_python_type(PyTypeObject * tp)288 static bool is_basic_python_type(PyTypeObject* tp) {
289   return (
290       /* Basic number types */
291       tp == &PyBool_Type ||
292 
293       tp == &PyLong_Type || tp == &PyFloat_Type || tp == &PyComplex_Type ||
294 
295       /* Basic sequence types */
296       tp == &PyList_Type || tp == &PyTuple_Type || tp == &PyDict_Type ||
297       tp == &PySet_Type || tp == &PyFrozenSet_Type || tp == &PyUnicode_Type ||
298       tp == &PyBytes_Type ||
299 
300       /* other builtins */
301       tp == &PySlice_Type || tp == Py_TYPE(Py_None) ||
302       tp == Py_TYPE(Py_Ellipsis) || tp == Py_TYPE(Py_NotImplemented) ||
303 
304       PyModule_Check(tp) ||
305       /* sentinel to swallow trailing || */
306       false);
307 }
308 
has_torch_function_attr(PyObject * obj)309 inline bool has_torch_function_attr(PyObject* obj) {
310   auto attr = PyObject_FastGetAttrString(obj, "__torch_function__");
311   return (
312       attr.ptr() != nullptr && attr.ptr() != torch::disabled_torch_function);
313 }
314 
315 namespace torch {
check_has_torch_function(PyObject * obj,bool ignore_mode)316 auto check_has_torch_function(PyObject* obj, bool ignore_mode) -> bool {
317   if (!ignore_mode && at::impl::torch_function_mode_enabled())
318     return true;
319   PyTypeObject* tp = Py_TYPE(obj);
320   return (
321       !THPVariable_CheckTypeExact(tp) && !is_basic_python_type(tp) &&
322       torch::torch_function_enabled() && has_torch_function_attr(obj));
323 }
324 } // namespace torch
325 
sequence_has_torch_function(PyObject * args)326 inline bool sequence_has_torch_function(PyObject* args) {
327   // NOLINTNEXTLINE(bugprone-branch-clone)
328   Py_ssize_t nargs = PySequence_Fast_GET_SIZE(args);
329   for (Py_ssize_t i = 0; i < nargs; i++) {
330     PyObject* obj = PySequence_Fast_GET_ITEM(args, i);
331     if (torch::check_has_torch_function(obj)) {
332       return true;
333     }
334   }
335   return false;
336 }
337 
array_has_torch_function(PyObject * const * args,Py_ssize_t nargs)338 inline bool array_has_torch_function(PyObject* const* args, Py_ssize_t nargs) {
339   for (Py_ssize_t i = 0; i < nargs; i++) {
340     if (torch::check_has_torch_function(args[i])) {
341       return true;
342     }
343   }
344   return false;
345 }
346 
THPModule_has_torch_function(PyObject *,PyObject * arg)347 PyObject* THPModule_has_torch_function(PyObject*, PyObject* arg) {
348   bool result; // NOLINT(cppcoreguidelines-init-variables)
349   if (PyTuple_CheckExact(arg) || PyList_CheckExact(arg)) {
350     // Fast path:
351     //   If we know that we have a tuple or list, we can skip an INCREF and
352     //   DECREF from PySequence_Fast. Core functions will always follow this
353     //   convention (almost always tuples), and it shaves ~3.5% off the cost of
354     //   the check.
355     result = sequence_has_torch_function(arg);
356   } else {
357     auto args = py::reinterpret_steal<py::object>(
358         PySequence_Fast(arg, "expected a sequence"));
359     if (!args) {
360       return nullptr;
361     }
362     result = sequence_has_torch_function(args.ptr());
363   }
364 
365   if (result) {
366     Py_RETURN_TRUE;
367   }
368   Py_RETURN_FALSE;
369 }
370 
THPModule_has_torch_function_unary(PyObject *,PyObject * obj)371 PyObject* THPModule_has_torch_function_unary(PyObject*, PyObject* obj) {
372   // Special case `THPModule_has_torch_function` for the single arg case.
373   if (torch::check_has_torch_function(obj)) {
374     Py_RETURN_TRUE;
375   }
376   Py_RETURN_FALSE;
377 }
378 
THPModule_has_torch_function_variadic(PyObject *,PyObject * const * args,Py_ssize_t nargs)379 PyObject* THPModule_has_torch_function_variadic(
380     PyObject*,
381     PyObject* const* args,
382     Py_ssize_t nargs) {
383   if (array_has_torch_function(args, nargs)) {
384     Py_RETURN_TRUE;
385   }
386   Py_RETURN_FALSE;
387 }
388