1 #include <torch/csrc/Exceptions.h>
2 #include <torch/csrc/autograd/python_variable.h>
3 #include <torch/csrc/utils/disable_torch_function.h>
4 #include <torch/csrc/utils/pybind.h>
5 #include <torch/csrc/utils/python_strings.h>
6
7 #include <ATen/PythonTorchFunctionTLS.h>
8
9 namespace torch {
10 PyObject* disabled_torch_function = nullptr;
11 PyObject* disabled_torch_dispatch = nullptr;
12
torch_function_enabled()13 bool torch_function_enabled() {
14 return at::impl::PythonTorchFunctionTLS::get_disabled_state() ==
15 at::impl::TorchFunctionDisabledState::ENABLED;
16 }
17
disabled_torch_function_impl()18 PyObject* disabled_torch_function_impl() {
19 return disabled_torch_function;
20 }
21
set_disabled_torch_function_impl(PyObject * value)22 void set_disabled_torch_function_impl(PyObject* value) {
23 disabled_torch_function = value;
24 }
25
disabled_torch_dispatch_impl()26 PyObject* disabled_torch_dispatch_impl() {
27 return disabled_torch_dispatch;
28 }
29
set_disabled_torch_dispatch_impl(PyObject * value)30 void set_disabled_torch_dispatch_impl(PyObject* value) {
31 disabled_torch_dispatch = value;
32 }
33 } // namespace torch
34
35 typedef struct {
36 PyObject_HEAD
37 /* Type-specific fields go here. */
38 at::impl::TorchFunctionDisabledState old_state;
39 } DisableTorchFunctionSubclass;
40
DisableTorchFunctionSubclass__enter(PyObject * self,PyObject * unused)41 PyObject* DisableTorchFunctionSubclass__enter(
42 PyObject* self,
43 PyObject* unused) {
44 const auto old_state = at::impl::PythonTorchFunctionTLS::get_disabled_state();
45 ((DisableTorchFunctionSubclass*)self)->old_state = old_state;
46 if (old_state == at::impl::TorchFunctionDisabledState::ENABLED) {
47 at::impl::PythonTorchFunctionTLS::set_disabled_state(
48 at::impl::TorchFunctionDisabledState::SUBCLASSES_DISABLED);
49 }
50 Py_RETURN_NONE;
51 }
52
DisableTorchFunctionSubclass__exit(PyObject * self,PyObject * unused)53 PyObject* DisableTorchFunctionSubclass__exit(PyObject* self, PyObject* unused) {
54 at::impl::PythonTorchFunctionTLS::set_disabled_state(
55 ((DisableTorchFunctionSubclass*)self)->old_state);
56 Py_RETURN_NONE;
57 }
58
THPModule_isEnabledTorchFunction(PyObject * self,PyObject * unused)59 PyObject* THPModule_isEnabledTorchFunction(PyObject* self, PyObject* unused) {
60 if (torch::torch_function_enabled()) {
61 Py_RETURN_TRUE;
62 } else {
63 Py_RETURN_FALSE;
64 }
65 }
66
THPModule_isAllDisabledTorchFunction(PyObject * self,PyObject * unused)67 PyObject* THPModule_isAllDisabledTorchFunction(
68 PyObject* self,
69 PyObject* unused) {
70 if (at::impl::torch_function_all_disabled()) {
71 Py_RETURN_TRUE;
72 } else {
73 Py_RETURN_FALSE;
74 }
75 }
76
77 static PyMethodDef DisableTorchFunctionSubclass_methods[] = { // NOLINT
78 {"__enter__", DisableTorchFunctionSubclass__enter, METH_NOARGS, nullptr},
79 {"__exit__", DisableTorchFunctionSubclass__exit, METH_VARARGS, nullptr},
80 {nullptr, nullptr, 0, nullptr}};
81
82 PyTypeObject DisableTorchFunctionSubclassType = {
83 PyVarObject_HEAD_INIT(
84 nullptr,
85 0) "torch._C.DisableTorchFunctionSubclass", /* tp_name */
86 sizeof(DisableTorchFunctionSubclass), /* tp_basicsize */
87 0, /* tp_itemsize */
88 nullptr, /* tp_dealloc */
89 0, /* tp_vectorcall_offset */
90 nullptr, /* tp_getattr */
91 nullptr, /* tp_setattr */
92 nullptr, /* tp_reserved */
93 nullptr, /* tp_repr */
94 nullptr, /* tp_as_number */
95 nullptr, /* tp_as_sequence */
96 nullptr, /* tp_as_mapping */
97 nullptr, /* tp_hash */
98 nullptr, /* tp_call */
99 nullptr, /* tp_str */
100 nullptr, /* tp_getattro */
101 nullptr, /* tp_setattro */
102 nullptr, /* tp_as_buffer */
103 Py_TPFLAGS_DEFAULT, /* tp_flags */
104 nullptr, /* tp_doc */
105 nullptr, /* tp_traverse */
106 nullptr, /* tp_clear */
107 nullptr, /* tp_richcompare */
108 0, /* tp_weaklistoffset */
109 nullptr, /* tp_iter */
110 nullptr, /* tp_iternext */
111 DisableTorchFunctionSubclass_methods, /* tp_methods */
112 nullptr, /* tp_members */
113 nullptr, /* tp_getset */
114 nullptr, /* tp_base */
115 nullptr, /* tp_dict */
116 nullptr, /* tp_descr_get */
117 nullptr, /* tp_descr_set */
118 0, /* tp_dictoffset */
119 nullptr, /* tp_init */
120 PyType_GenericAlloc, /* tp_alloc */
121 PyType_GenericNew, /* tp_new */
122 };
123
THPModule_DisableTorchFunctionSubclassType()124 PyObject* THPModule_DisableTorchFunctionSubclassType() {
125 if (PyType_Ready(&DisableTorchFunctionSubclassType) < 0) {
126 return nullptr;
127 }
128
129 return (PyObject*)(&DisableTorchFunctionSubclassType);
130 }
131
132 typedef struct {
133 PyObject_HEAD
134 /* Type-specific fields go here. */
135 at::impl::TorchFunctionDisabledState old_state;
136 } DisableTorchFunction;
137
DisableTorchFunction__enter(PyObject * self,PyObject * unused)138 PyObject* DisableTorchFunction__enter(PyObject* self, PyObject* unused) {
139 ((DisableTorchFunctionSubclass*)self)->old_state =
140 at::impl::PythonTorchFunctionTLS::get_disabled_state();
141 at::impl::PythonTorchFunctionTLS::set_disabled_state(
142 at::impl::TorchFunctionDisabledState::ALL_DISABLED);
143 Py_RETURN_NONE;
144 }
145
DisableTorchFunction__exit(PyObject * self,PyObject * unused)146 PyObject* DisableTorchFunction__exit(PyObject* self, PyObject* unused) {
147 at::impl::PythonTorchFunctionTLS::set_disabled_state(
148 ((DisableTorchFunctionSubclass*)self)->old_state);
149 Py_RETURN_NONE;
150 }
151
152 static PyMethodDef DisableTorchFunction_methods[] = { // NOLINT
153 {"__enter__", DisableTorchFunction__enter, METH_NOARGS, nullptr},
154 {"__exit__", DisableTorchFunction__exit, METH_VARARGS, nullptr},
155 {nullptr, nullptr, 0, nullptr}};
156
157 PyTypeObject DisableTorchFunctionType = {
158 PyVarObject_HEAD_INIT(
159 nullptr,
160 0) "torch._C.DisableTorchFunction", /* tp_name */
161 sizeof(DisableTorchFunction), /* tp_basicsize */
162 0, /* tp_itemsize */
163 nullptr, /* tp_dealloc */
164 0, /* tp_vectorcall_offset */
165 nullptr, /* tp_getattr */
166 nullptr, /* tp_setattr */
167 nullptr, /* tp_reserved */
168 nullptr, /* tp_repr */
169 nullptr, /* tp_as_number */
170 nullptr, /* tp_as_sequence */
171 nullptr, /* tp_as_mapping */
172 nullptr, /* tp_hash */
173 nullptr, /* tp_call */
174 nullptr, /* tp_str */
175 nullptr, /* tp_getattro */
176 nullptr, /* tp_setattro */
177 nullptr, /* tp_as_buffer */
178 Py_TPFLAGS_DEFAULT, /* tp_flags */
179 nullptr, /* tp_doc */
180 nullptr, /* tp_traverse */
181 nullptr, /* tp_clear */
182 nullptr, /* tp_richcompare */
183 0, /* tp_weaklistoffset */
184 nullptr, /* tp_iter */
185 nullptr, /* tp_iternext */
186 DisableTorchFunction_methods, /* tp_methods */
187 nullptr, /* tp_members */
188 nullptr, /* tp_getset */
189 nullptr, /* tp_base */
190 nullptr, /* tp_dict */
191 nullptr, /* tp_descr_get */
192 nullptr, /* tp_descr_set */
193 0, /* tp_dictoffset */
194 nullptr, /* tp_init */
195 PyType_GenericAlloc, /* tp_alloc */
196 PyType_GenericNew, /* tp_new */
197 };
198
THPModule_DisableTorchFunctionType()199 PyObject* THPModule_DisableTorchFunctionType() {
200 if (PyType_Ready(&DisableTorchFunctionType) < 0) {
201 return nullptr;
202 }
203
204 return (PyObject*)(&DisableTorchFunctionType);
205 }
206
THPModule_disable_torch_function(PyObject * self,PyObject * a)207 PyObject* THPModule_disable_torch_function(PyObject* self, PyObject* a) {
208 HANDLE_TH_ERRORS
209 PyObject *func = nullptr, *types = nullptr, *args = nullptr,
210 *kwargs = nullptr;
211 if (!PyArg_ParseTuple(a, "OO|OO", &func, &types, &args, &kwargs)) {
212 return nullptr;
213 }
214 py::tuple py_args;
215 if (args == nullptr) {
216 py_args = py::make_tuple();
217 } else if (PyList_Check(args)) {
218 py_args = py::reinterpret_steal<py::tuple>(PyList_AsTuple(args));
219 } else if (PyTuple_Check(args)) {
220 py_args = py::reinterpret_borrow<py::tuple>(args);
221 } else {
222 throw torch::TypeError(
223 "expected List or Tuple (got %s)", Py_TYPE(args)->tp_name);
224 }
225
226 // These are all C-API calls so no exceptions will be raised
227 // and therefore no need for RAII approach to storing
228 // the old value.
229 auto old_value = at::impl::PythonTorchFunctionTLS::get_disabled_state();
230 if (old_value == at::impl::TorchFunctionDisabledState::ENABLED) {
231 at::impl::PythonTorchFunctionTLS::set_disabled_state(
232 at::impl::TorchFunctionDisabledState::SUBCLASSES_DISABLED);
233 }
234 // kwargs can safely be nullptr here.
235 PyObject* result = PyObject_Call(func, py_args.ptr(), kwargs);
236 at::impl::PythonTorchFunctionTLS::set_disabled_state(old_value);
237 return result;
238 END_HANDLE_TH_ERRORS
239 }
240
THPModule_disable_torch_dispatch(PyObject * self,PyObject * a)241 PyObject* THPModule_disable_torch_dispatch(PyObject* self, PyObject* a) {
242 HANDLE_TH_ERRORS
243 PyObject *func = nullptr, *types = nullptr, *args = nullptr,
244 *kwargs = nullptr;
245 if (!PyArg_ParseTuple(a, "OO|OO", &func, &types, &args, &kwargs)) {
246 return nullptr;
247 }
248 py::tuple py_args;
249 if (args == nullptr) {
250 py_args = py::make_tuple();
251 } else if (PyList_Check(args)) {
252 py_args = py::reinterpret_steal<py::tuple>(PyList_AsTuple(args));
253 } else if (PyTuple_Check(args)) {
254 py_args = py::reinterpret_borrow<py::tuple>(args);
255 } else {
256 throw torch::TypeError(
257 "expected List or Tuple (got %s)", Py_TYPE(args)->tp_name);
258 }
259
260 // This implementation is not completely correct. The moral
261 // meaning of this function is that we should do a redispatch
262 // "after" PythonKey, aka a redispatch() call. But we don't have a
263 // dispatcher call here; we have an opaque Python object.
264 //
265 // What we have here is a close approximation: instead of redispatch(), we
266 // just exclude Python and all the keys before it, so that we will go
267 // to the next key after Python. The difference, however, is we are
268 // now PERMANENTLY after Python. We don't think there are any legitimate
269 // cases where we want to go for another round on the entire dispatcher key
270 // set, but if there are, then we will have to do something else here.
271 c10::impl::ExcludeDispatchKeyGuard guard_(
272 // TODO: add constructor for this specifically
273 c10::DispatchKeySet(c10::DispatchKeySet::FULL) -
274 c10::DispatchKeySet(
275 c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::Python)
276 // NB: off by one hazard here, but it works out: python key is not
277 // included in AFTER, so it is included in the negation (and that's
278 // correct: we want to exclude Python key and everything BEFORE it.)
279 );
280 auto r = PyObject_Call(func, py_args.ptr(), kwargs);
281 if (r == nullptr)
282 throw python_error();
283 return r;
284 END_HANDLE_TH_ERRORS
285 }
286
287 // Makes sure that we don't check for __torch_function__ on basic Python types
is_basic_python_type(PyTypeObject * tp)288 static bool is_basic_python_type(PyTypeObject* tp) {
289 return (
290 /* Basic number types */
291 tp == &PyBool_Type ||
292
293 tp == &PyLong_Type || tp == &PyFloat_Type || tp == &PyComplex_Type ||
294
295 /* Basic sequence types */
296 tp == &PyList_Type || tp == &PyTuple_Type || tp == &PyDict_Type ||
297 tp == &PySet_Type || tp == &PyFrozenSet_Type || tp == &PyUnicode_Type ||
298 tp == &PyBytes_Type ||
299
300 /* other builtins */
301 tp == &PySlice_Type || tp == Py_TYPE(Py_None) ||
302 tp == Py_TYPE(Py_Ellipsis) || tp == Py_TYPE(Py_NotImplemented) ||
303
304 PyModule_Check(tp) ||
305 /* sentinel to swallow trailing || */
306 false);
307 }
308
has_torch_function_attr(PyObject * obj)309 inline bool has_torch_function_attr(PyObject* obj) {
310 auto attr = PyObject_FastGetAttrString(obj, "__torch_function__");
311 return (
312 attr.ptr() != nullptr && attr.ptr() != torch::disabled_torch_function);
313 }
314
315 namespace torch {
check_has_torch_function(PyObject * obj,bool ignore_mode)316 auto check_has_torch_function(PyObject* obj, bool ignore_mode) -> bool {
317 if (!ignore_mode && at::impl::torch_function_mode_enabled())
318 return true;
319 PyTypeObject* tp = Py_TYPE(obj);
320 return (
321 !THPVariable_CheckTypeExact(tp) && !is_basic_python_type(tp) &&
322 torch::torch_function_enabled() && has_torch_function_attr(obj));
323 }
324 } // namespace torch
325
sequence_has_torch_function(PyObject * args)326 inline bool sequence_has_torch_function(PyObject* args) {
327 // NOLINTNEXTLINE(bugprone-branch-clone)
328 Py_ssize_t nargs = PySequence_Fast_GET_SIZE(args);
329 for (Py_ssize_t i = 0; i < nargs; i++) {
330 PyObject* obj = PySequence_Fast_GET_ITEM(args, i);
331 if (torch::check_has_torch_function(obj)) {
332 return true;
333 }
334 }
335 return false;
336 }
337
array_has_torch_function(PyObject * const * args,Py_ssize_t nargs)338 inline bool array_has_torch_function(PyObject* const* args, Py_ssize_t nargs) {
339 for (Py_ssize_t i = 0; i < nargs; i++) {
340 if (torch::check_has_torch_function(args[i])) {
341 return true;
342 }
343 }
344 return false;
345 }
346
THPModule_has_torch_function(PyObject *,PyObject * arg)347 PyObject* THPModule_has_torch_function(PyObject*, PyObject* arg) {
348 bool result; // NOLINT(cppcoreguidelines-init-variables)
349 if (PyTuple_CheckExact(arg) || PyList_CheckExact(arg)) {
350 // Fast path:
351 // If we know that we have a tuple or list, we can skip an INCREF and
352 // DECREF from PySequence_Fast. Core functions will always follow this
353 // convention (almost always tuples), and it shaves ~3.5% off the cost of
354 // the check.
355 result = sequence_has_torch_function(arg);
356 } else {
357 auto args = py::reinterpret_steal<py::object>(
358 PySequence_Fast(arg, "expected a sequence"));
359 if (!args) {
360 return nullptr;
361 }
362 result = sequence_has_torch_function(args.ptr());
363 }
364
365 if (result) {
366 Py_RETURN_TRUE;
367 }
368 Py_RETURN_FALSE;
369 }
370
THPModule_has_torch_function_unary(PyObject *,PyObject * obj)371 PyObject* THPModule_has_torch_function_unary(PyObject*, PyObject* obj) {
372 // Special case `THPModule_has_torch_function` for the single arg case.
373 if (torch::check_has_torch_function(obj)) {
374 Py_RETURN_TRUE;
375 }
376 Py_RETURN_FALSE;
377 }
378
THPModule_has_torch_function_variadic(PyObject *,PyObject * const * args,Py_ssize_t nargs)379 PyObject* THPModule_has_torch_function_variadic(
380 PyObject*,
381 PyObject* const* args,
382 Py_ssize_t nargs) {
383 if (array_has_torch_function(args, nargs)) {
384 Py_RETURN_TRUE;
385 }
386 Py_RETURN_FALSE;
387 }
388