// torch/csrc/Module.cpp (AOSP xref snapshot, revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <ATen/DeviceAccelerator.h>
#include <fmt/core.h>
#include <sys/types.h>
#include <torch/csrc/python_headers.h>
#include <optional>

#ifndef _MSC_VER
#include <sys/socket.h>
#endif

#include <ATen/ATen.h>
#include <ATen/BlasBackend.h>
#include <ATen/CachedTensorUtils.h>
#include <ATen/DLConvertor.h>
#include <ATen/ExpandUtils.h>
#include <ATen/LegacyVmapMode.h>
#include <ATen/LinalgBackend.h>
#include <ATen/Parallel.h>
#include <ATen/Utils.h>
#include <ATen/core/Vitals.h>
#include <ATen/detail/AcceleratorHooksInterface.h>
#include <ATen/dlpack.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/ForeachUtils.h>
#include <ATen/native/Normalization.h>
#include <c10/core/Device.h>
#include <c10/core/DispatchKeySet.h>
#include <c10/util/AbortHandler.h>
#include <c10/util/Backtrace.h>
#include <c10/util/Logging.h>
#include <c10/util/irange.h>
#include <c10/util/thread_name.h>
#include <libshm.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/csrc/THConcat.h>
#include <torch/csrc/utils/pybind.h>
#include <cstdlib>
#include <iostream>
#include <unordered_map>

#include <ATen/ThreadLocalPythonObjects.h>
#include <torch/csrc/DataLoader.h>
#include <torch/csrc/Device.h>
#include <torch/csrc/Dtype.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/Event.h>
#include <torch/csrc/Generator.h>
#include <torch/csrc/Layout.h>
#include <torch/csrc/MemoryFormat.h>
#include <torch/csrc/QScheme.h>
#include <torch/csrc/Stream.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/TypeInfo.h>
#include <torch/csrc/api/include/torch/python/init.h>
#include <torch/csrc/autograd/generated/python_return_types.h>
#include <torch/csrc/autograd/python_cpp_function.h>
#include <torch/csrc/autograd/python_enum_tag.h>
#include <torch/csrc/autograd/python_fft_functions.h>
#include <torch/csrc/autograd/python_function.h>
#include <torch/csrc/autograd/python_legacy_variable.h>
#include <torch/csrc/autograd/python_linalg_functions.h>
#include <torch/csrc/autograd/python_nested_functions.h>
#include <torch/csrc/autograd/python_nn_functions.h>
#include <torch/csrc/autograd/python_sparse_functions.h>
#include <torch/csrc/autograd/python_special_functions.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/cpu/Module.h>
#include <torch/csrc/dynamo/init.h>
#include <torch/csrc/functorch/init.h>
#include <torch/csrc/fx/node.h>
#include <torch/csrc/inductor/aoti_runner/pybind.h>
#include <torch/csrc/instruction_counter/Module.h>
#include <torch/csrc/jit/python/init.h>
#include <torch/csrc/jit/python/python_ir.h>
#include <torch/csrc/jit/python/python_tracer.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <torch/csrc/lazy/python/init.h>
#include <torch/csrc/monitor/python_init.h>
#include <torch/csrc/mps/Module.h>
#include <torch/csrc/mtia/Module.h>
#include <torch/csrc/multiprocessing/init.h>
#include <torch/csrc/onnx/init.h>
#include <torch/csrc/profiler/python/init.h>
#include <torch/csrc/tensor/python_tensor.h>
#include <torch/csrc/utils/disable_torch_function.h>
#include <torch/csrc/utils/init.h>
#include <torch/csrc/utils/pycfunction_helpers.h>
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_compat.h>
#include <torch/csrc/utils/python_dispatch.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/tensor_dtypes.h>
#include <torch/csrc/utils/tensor_layouts.h>
#include <torch/csrc/utils/tensor_memoryformats.h>
#include <torch/csrc/utils/tensor_new.h>
#include <torch/csrc/utils/tensor_numpy.h>
#include <torch/csrc/utils/tensor_qschemes.h>
#include <torch/csrc/utils/verbose.h>

#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <torch/csrc/profiler/combined_traceback.h>
#include <sstream>

#ifdef USE_CUDA
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/native/transformers/cuda/sdp_utils.h>
#ifdef __HIP_PLATFORM_AMD__
#include <ATen/native/cudnn/hip/BatchNorm.h>
#else
#include <ATen/native/cudnn/BatchNorm.h>
#endif
#endif

#ifdef USE_DISTRIBUTED
#ifdef USE_C10D
#include <torch/csrc/distributed/autograd/python_autograd.h>
#include <torch/csrc/distributed/c10d/c10d.h>
#include <torch/csrc/distributed/rpc/rpc.h>
#include <torch/csrc/distributed/rpc/testing/testing.h>
#endif
#endif

#if defined(USE_VALGRIND)
#include <callgrind.h>
#endif

namespace py = pybind11;

PyObject* module;

THPGenerator* THPDefaultCPUGenerator = nullptr;

////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////

static PyObject* THPModule_initNames(PyObject* self, PyObject* arg) {
  HANDLE_TH_ERRORS
  static std::vector<std::string> names;

  THPObjectPtr types(PySequence_Fast(arg, "expected a sequence"));
  if (!types)
    return nullptr;

  // NOLINTNEXTLINE(bugprone-branch-clone)
  auto num_classes = PySequence_Fast_GET_SIZE(types.get());
  names.reserve(names.size() + num_classes);
  for (Py_ssize_t i = 0; i < num_classes; i++) {
    PyObject* obj = PySequence_Fast_GET_ITEM(types.get(), i);
    TORCH_CHECK(PyType_Check(obj), "expected a PyTypeObject");
    PyTypeObject* type = (PyTypeObject*)obj;

    THPObjectPtr module_name(PyObject_GetAttrString(obj, "__module__"));
    if (!module_name)
      return nullptr;
    TORCH_CHECK(
        THPUtils_checkString(module_name.get()),
        "expected __module__ to be a string");
    std::string name = THPUtils_unpackString(module_name.get());
    names.emplace_back(name + "." + type->tp_name);
    type->tp_name = names.back().c_str();
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
//
// Callback for the Python side. Used for additional initialization of Python
// classes.
static PyObject* THPModule_initExtension(
    PyObject* _unused,
    PyObject* shm_manager_path) {
  HANDLE_TH_ERRORS
#if !defined(FBCODE_CAFFE2) && !defined(__aarch64__)
  if (torch::get_cpp_stacktraces_enabled()) {
    c10::SetStackTraceFetcher([]() -> std::string {
      auto tb = torch::CapturedTraceback::gather(false, false, true);
      if (torch::get_symbolize_mode() == torch::unwind::Mode::addr2line) {
        LOG(WARNING)
            << "symbolizing C++ stack trace for exception; if this hangs, rerun with TORCH_DISABLE_ADDR2LINE=1..."
            << std::endl;
      }
      auto s_tbs = torch::symbolize({tb.get()});
      std::stringstream oss;
      oss << "C++ CapturedTraceback:" << std::endl;
      const auto& s_tb = s_tbs.tracebacks.at(0);
      for (auto idx : c10::irange(s_tb.size())) {
        // Skip the first few frames:
        //  #1 torch::CapturedTraceback::gather(bool, bool, bool)
        //  #2 THPModule_initExtension
        //  #3 THPModule_initExtension(_object*, _object*)::{lambda()#1}
        if (idx <= 3) {
          continue;
        }
        auto frame_id = s_tb[idx];
        const auto& frame = s_tbs.all_frames.at(frame_id);
        oss << "#" << idx << " " << frame.funcname << " from " << frame.filename
            << ":" << frame.lineno << std::endl;
      }
      return oss.str();
    });
  }
#endif
  if (!THPUtils_checkString(shm_manager_path)) {
    THPUtils_setError(
        "initialization error - expected bytes/string object as shm_manager_path!");
    return nullptr;
  }
  torch::utils::initializeLayouts();
  torch::utils::initializeMemoryFormats();
  torch::utils::initializeQSchemes();
  torch::utils::initializeDtypes();
  torch::tensors::initialize_python_bindings();
  std::string path = THPUtils_unpackString(shm_manager_path);
  libshm_init(path.c_str());

  auto module = THPObjectPtr(PyImport_ImportModule("torch"));
  if (!module)
    throw python_error();

  THPStorage_postInit(module);
  THPAutograd_initFunctions();
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
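//
// How this gets invoked (a hedged sketch, not shown in this file): torch's
// Python package is expected to call the `_initExtension` binding from the
// TorchMethods table below during import, passing the shm manager path,
// roughly:
//
//   # torch/__init__.py (assumed call site)
//   # _C._initExtension(manager_path())
//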
// The idea behind these two functions is to make it easy to test if we are
// built with ASAN: they're designed not to crash if ASAN is not enabled, but
// to trigger ASAN if it is enabled. This lets us run "canary" tests that
// check whether our build environment is misconfigured.
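//
// A minimal sketch of exercising the canaries from Python. The torch._C call
// sites below are assumptions inferred from the names registered in the
// TorchMethods table at the bottom of this file:
//
//   import torch
//   torch._C._crash_if_csrc_asan(4)   # writes past a 3-byte stack buffer;
//                                     # aborts only when built with ASAN
//   torch._C._crash_if_csrc_ubsan(0)  # 1.0 / 0, then a float->int cast of
//                                     # inf; flagged only under UBSAN
//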
static PyObject* THPModule_crashIfCsrcASAN(PyObject* module, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      THPUtils_checkLong(arg),
      "crash_if_csrc_asan expects an int, but got ",
      THPUtils_typename(arg));
  // NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays, modernize-avoid-c-arrays)
  volatile char x[3];
  x[THPUtils_unpackInt(arg)] = 0;
  // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
  return THPUtils_packInt32(x[0]);
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_crashIfCsrcUBSAN(PyObject* module, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      THPUtils_checkLong(arg),
      "crash_if_csrc_ubsan expects an int, but got ",
      THPUtils_typename(arg));
  int32_t x = THPUtils_unpackInt(arg);
  double y = 1.0 / x;
  return THPUtils_packInt32((int)y);
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_crashIfvptrUBSAN(PyObject* module, PyObject* noarg) {
  // This code should work perfectly fine, as vtables are identical for Foo and
  // Baz unless rtti and ubsan are enabled
  struct Foo {
    virtual int bar() = 0;
    virtual ~Foo() = default;
  };
  struct Baz {
    virtual int bar() {
      return 17;
    }
    virtual ~Baz() = default;
  };
  Baz x{};
  auto y = static_cast<Foo*>(static_cast<void*>(&x));
  auto rc = y->bar();
  return THPUtils_packInt32(rc);
}

static PyObject* THPModule_crashIfATenASAN(PyObject* module, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      THPUtils_checkLong(arg),
      "crash_if_aten_asan expects an int, "
      "but got ",
      THPUtils_typename(arg));
  return THPUtils_packInt32(at::_crash_if_asan(THPUtils_unpackInt(arg)));
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_abort(PyObject* module, PyObject* noargs) {
  std::terminate();
  Py_RETURN_NONE;
}

static PyObject* THPModule_crashIfDebugAssertsFail(
    PyObject* module,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      THPUtils_checkLong(arg),
      "crash_if_debug_asserts_fail expects an int, but got ",
      THPUtils_typename(arg));
  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
      THPUtils_unpackInt(arg) != 424242,
      "Expect anything but 424242 as an input for debug builds");
  return THPUtils_packInt32(0);
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_getNumThreads(PyObject* module, PyObject* noargs) {
  return THPUtils_packInt32(at::get_num_threads());
}

static PyObject* THPModule_setNumThreads(PyObject* module, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      THPUtils_checkLong(arg),
      "set_num_threads expects an int, but got ",
      THPUtils_typename(arg));
  int nthreads = (int)THPUtils_unpackLong(arg);
  TORCH_CHECK(nthreads > 0, "set_num_threads expects a positive integer");
  at::set_num_threads(nthreads);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_getNumInteropThreads(
    PyObject* module,
    PyObject* noargs) {
  return THPUtils_packInt32(at::get_num_interop_threads());
}

static PyObject* THPModule_setNumInteropThreads(
    PyObject* module,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      THPUtils_checkLong(arg),
      "set_num_interop_threads expects an int, "
      "but got ",
      THPUtils_typename(arg));
  int nthreads = (int)THPUtils_unpackLong(arg);
  TORCH_CHECK(
      nthreads > 0, "set_num_interop_threads expects a positive integer");
  at::set_num_interop_threads(nthreads);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
PyObject* THPModule_setDefaultTensorType(PyObject* _unused, PyObject* type) {
  HANDLE_TH_ERRORS
  torch::tensors::py_set_default_tensor_type(type);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_setDefaultDtype(PyObject* _unused, PyObject* dtype) {
  HANDLE_TH_ERRORS
  torch::tensors::py_set_default_dtype(dtype);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_swap_tensor_impl(PyObject* _unused, PyObject* args) {
  HANDLE_TH_ERRORS
  PyObject* a_ = nullptr;
  PyObject* b_ = nullptr;
  if (!PyArg_ParseTuple(args, "OO", &a_, &b_)) {
    return nullptr;
  }

  // Ensure we have Tensors
  TORCH_CHECK(THPVariable_Check(a_));
  TORCH_CHECK(THPVariable_Check(b_));

  THPVariable* a = reinterpret_cast<THPVariable*>(a_);
  THPVariable* b = reinterpret_cast<THPVariable*>(b_);

  // weak_use_count() adds 1 if use_count is non-zero
  TORCH_CHECK(
      a->cdata->weak_use_count() == 1,
      "Expected no weakrefs to t1's Tensor object but got ",
      a->cdata->weak_use_count() - 1);
  TORCH_CHECK(
      b->cdata->weak_use_count() == 1,
      "Expected no weakrefs to t2's Tensor object but got ",
      b->cdata->weak_use_count() - 1);

  // Swap the Tensor Impl
  c10::MaybeOwned<at::Tensor> tmp = a->cdata;

  // The TensorImpls contain PyObjectSlots that have a reference to the PyObject
  // associated with the TensorImpl. Swap this field as well.
  std::optional<PyObject*> mb_obj_a =
      a->cdata->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
          getPyInterpreter(), /*ignore_hermetic_tls=*/false);
  std::optional<PyObject*> mb_obj_b =
      b->cdata->unsafeGetTensorImpl()->pyobj_slot()->check_pyobj(
          getPyInterpreter(), /*ignore_hermetic_tls=*/false);
  TORCH_INTERNAL_ASSERT(
      mb_obj_a.has_value() && mb_obj_b.has_value(),
      "Both tensors should have PyObjects tagged by the current python interpreter");
  TORCH_CHECK(mb_obj_a.value() == a_);
  TORCH_CHECK(mb_obj_b.value() == b_);

  a->cdata = b->cdata;
  b->cdata = tmp;

  a->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(
      getPyInterpreter(), a_, c10::impl::PyInterpreterStatus::TAGGED_BY_US);
  b->cdata->unsafeGetTensorImpl()->pyobj_slot()->init_pyobj(
      getPyInterpreter(), b_, c10::impl::PyInterpreterStatus::TAGGED_BY_US);

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
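//
// A minimal usage sketch (hedged: an internal testing helper, not a public
// API; assumes two fresh tensors with no outstanding weak references):
//
//   t1, t2 = torch.randn(2), torch.randn(3)
//   torch._C._swap_tensor_impl(t1, t2)
//   # t1 now has shape (3,) and t2 shape (2,): the Python objects keep their
//   # identities while the underlying TensorImpls trade places.
//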
PyObject* THPModule_addDocStr(PyObject* _unused, PyObject* args) {
  // adds a __doc__ string to a function, similar to numpy's arr_add_docstring
  static std::vector<std::string> all_docs;
  PyObject* obj = nullptr;
  PyObject* doc_obj = nullptr;
  if (!PyArg_ParseTuple(args, "OO", &obj, &doc_obj)) {
    return nullptr;
  }

  const char* doc_str = "<invalid string>";
  if (THPUtils_checkString(doc_obj)) {
    all_docs.push_back(THPUtils_unpackString(doc_obj));
    doc_str = all_docs.back().c_str();
  }

  if (Py_TYPE(obj) == &PyCFunction_Type) {
    PyCFunctionObject* f = (PyCFunctionObject*)obj;
    if (f->m_ml->ml_doc) {
      return PyErr_Format(
          PyExc_RuntimeError,
          "function '%s' already has a docstring",
          f->m_ml->ml_name);
    }
    f->m_ml->ml_doc = doc_str;
  } else if (strcmp(Py_TYPE(obj)->tp_name, "method_descriptor") == 0) {
    PyMethodDescrObject* m = (PyMethodDescrObject*)obj;
    if (m->d_method->ml_doc) {
      return PyErr_Format(
          PyExc_RuntimeError,
          "method '%s' already has a docstring",
          m->d_method->ml_name);
    }
    m->d_method->ml_doc = doc_str;
  } else if (strcmp(Py_TYPE(obj)->tp_name, "getset_descriptor") == 0) {
    // NOLINTNEXTLINE(cppcoreguidelines-pro-type-cstyle-cast)
    PyGetSetDescrObject* m = (PyGetSetDescrObject*)obj;
    if (m->d_getset->doc) {
      return PyErr_Format(
          PyExc_RuntimeError,
          "attribute '%s' already has a docstring",
          m->d_getset->name);
    }
    m->d_getset->doc = doc_str;
  } else if (Py_TYPE(obj) == &PyType_Type) {
    PyTypeObject* t = (PyTypeObject*)obj;
    if (t->tp_doc) {
      return PyErr_Format(
          PyExc_RuntimeError, "Type '%s' already has a docstring", t->tp_name);
    }
    t->tp_doc = doc_str;
  } else {
    return PyErr_Format(
        PyExc_TypeError,
        "don't know how to add docstring to type '%s'",
        Py_TYPE(obj)->tp_name);
  }

  Py_INCREF(obj);
  return obj;
}
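//
// A minimal usage sketch (hedged: torch's Python sources attach docstrings to
// C-level functions this way; the exact target below is illustrative):
//
//   torch._C._add_docstr(
//       torch._C._VariableFunctions.abs,
//       "abs(input) -> Tensor\n\nComputes the absolute value of each element.")
//
// The static `all_docs` vector above keeps the strings alive for the life of
// the process, since CPython only stores the raw `const char*`.
//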
PyObject* THPModule_inferSize(PyObject* _unused, PyObject* args) {
  HANDLE_TH_ERRORS
  Py_ssize_t num_args = args ? (Py_ssize_t)PyTuple_Size(args) : 0;
  TORCH_CHECK(num_args == 2, "expected exactly 2 arguments");
  PyObject* arg1 = PyTuple_GET_ITEM(args, 0);
  TORCH_CHECK(THPSize_Check(arg1), "expected a torch.Size as argument 1");
  PyObject* arg2 = PyTuple_GET_ITEM(args, 1);
  TORCH_CHECK(THPSize_Check(arg2), "expected a torch.Size as argument 2");

  auto size1 = THPUtils_unpackLongs(arg1);
  auto size2 = THPUtils_unpackLongs(arg2);
  auto sizes = at::infer_size(size1, size2);
  return THPSize_NewFromSizes(static_cast<int64_t>(sizes.size()), sizes.data());
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_setBackcompatBroadcastWarn(
    PyObject* module,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_backcompat_broadcast_warn expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  setBackCompatBroadcastWarn(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_getBackcompatBroadcastWarn(
    PyObject* module,
    PyObject* noargs) {
  if (getBackCompatBroadcastWarn())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}

static PyObject* THPModule_setBackcompatKeepdimWarn(
    PyObject* module,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_backcompat_keepdim_warn expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  setBackCompatKeepdimWarn(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_getBackcompatKeepdimWarn(
    PyObject* module,
    PyObject* noargs) {
  if (getBackCompatKeepdimWarn())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}

PyObject* THPModule_hasDistributed(PyObject* _unused, PyObject* noargs) {
#ifdef USE_DISTRIBUTED
  Py_RETURN_TRUE;
#else
  Py_RETURN_FALSE;
#endif
}

static PyObject* THPModule_showConfig(PyObject* module, PyObject* noargs) {
  HANDLE_TH_ERRORS
  return THPUtils_packString(at::show_config());
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_cxxFlags(PyObject* module, PyObject* noargs) {
  HANDLE_TH_ERRORS
  return THPUtils_packString(at::get_cxx_flags());
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_parallelInfo(PyObject* module, PyObject* noargs) {
  HANDLE_TH_ERRORS
  return THPUtils_packString(at::get_parallel_info());
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_getCpuCapability(
    PyObject* module,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  return THPUtils_packString(at::get_cpu_capability());
  END_HANDLE_TH_ERRORS
}

void DLPack_Capsule_Destructor(PyObject* data) {
  if (C10_LIKELY(!PyCapsule_IsValid(data, "dltensor"))) {
    // early out, see DLPack spec: if a consuming library sets the capsule
    // name to something else, they own it and we don't need to do anything
    return;
  }
  HANDLE_TH_ERRORS
  // This re-runs the validity check, adding some overhead, but the case is
  // rare: consuming libraries are supposed to rename the capsule per the spec.
  // Note that this cannot set a python error (we checked validity above),
  // so we don't need to handle python error state here.
  DLManagedTensor* dlMTensor =
      (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor");
  // the dlMTensor has not been consumed, so call the deleter ourselves.
  // The DLPack spec mentions that the deleter may be NULL, but the deleter
  // from `at::toDLPack` is never NULL, so no additional check is needed here.
  dlMTensor->deleter(dlMTensor);
  END_HANDLE_TH_ERRORS_RET()
}
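//
// The capsule protocol in brief, per the DLPack spec referenced above: the
// producer names the capsule "dltensor"; a consumer that takes ownership
// renames it (conventionally "used_dltensor"), which turns the destructor
// above into a no-op. A minimal round trip through the public wrappers
// (hedged: torch.utils.dlpack ultimately reaches the bindings below):
//
//   t = torch.arange(4)
//   capsule = torch.utils.dlpack.to_dlpack(t)    # -> _to_dlpack
//   u = torch.utils.dlpack.from_dlpack(capsule)  # -> _from_dlpack
//   assert u.data_ptr() == t.data_ptr()          # zero-copy
//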
PyObject* THPModule_toDLPack(PyObject* _unused, PyObject* data) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(THPVariable_Check(data), "data must be a Tensor");
  DLManagedTensor* dlMTensor = at::toDLPack(THPVariable_Unpack(data));
  return PyCapsule_New(dlMTensor, "dltensor", DLPack_Capsule_Destructor);
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_fromDLPack(PyObject* _unused, PyObject* data) {
  using namespace torch::autograd;
  HANDLE_TH_ERRORS
  auto tensor = torch::utils::tensor_fromDLPack(data);
  return THPVariable_Wrap(tensor);
  END_HANDLE_TH_ERRORS
}

PyObject* THModule_getCppBacktrace(PyObject* _unused, PyObject* args) {
  HANDLE_TH_ERRORS
  size_t frames_to_skip = 0;
  size_t maximum_number_of_frames = 0;
  if (!PyArg_ParseTuple(
          args, "LL", &frames_to_skip, &maximum_number_of_frames)) {
    return nullptr;
  }
  return THPUtils_packString(
      c10::get_backtrace(frames_to_skip, maximum_number_of_frames, true));
  END_HANDLE_TH_ERRORS
}

static PyObject* THModule_rename_privateuse1_backend(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      THPUtils_checkString(arg),
      "_rename_privateuse1_backend expects a str, but got ",
      THPUtils_typename(arg));
  const std::string backend_name = THPUtils_unpackString(arg);
  c10::register_privateuse1_backend(backend_name);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* THModule_get_privateuse1_backend_name(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  return THPUtils_packString(c10::get_privateuse1_backend());
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_setAllowTF32CuDNN(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_allow_tf32_cudnn expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setAllowTF32CuDNN(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_allowTF32CuDNN(PyObject* _unused, PyObject* noargs) {
  if (at::globalContext().allowTF32CuDNN())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}

PyObject* THPModule_setFloat32MatmulPrecision(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      THPUtils_checkString(arg),
      "set_float32_matmul_precision expects a str, "
      "but got ",
      THPUtils_typename(arg));
  std::string s = THPUtils_unpackString(arg);
  at::globalContext().setFloat32MatmulPrecision(s);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_float32MatmulPrecision(
    PyObject* _unused,
    PyObject* noargs) {
  std::string s = "highest";
  auto p = at::globalContext().float32MatmulPrecision();
  if (p == at::Float32MatmulPrecision::HIGH) {
    s = "high";
  } else if (p == at::Float32MatmulPrecision::MEDIUM) {
    s = "medium";
  }
  return THPUtils_packString(s);
}
PyObject* THPModule_setSDPUseFlash(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_sdp_use_flash expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setSDPUseFlash(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
PyObject* THPModule_userEnabledFlashSDP(PyObject* _unused, PyObject* noargs) {
  if (at::globalContext().userEnabledFlashSDP())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}
PyObject* THPModule_setSDPUseMemEfficient(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_sdp_use_mem_efficient expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setSDPUseMemEfficient(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
PyObject* userEnabledMemEfficientSDP(PyObject* _unused, PyObject* noargs) {
  if (at::globalContext().userEnabledMemEfficientSDP())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}
PyObject* THPModule_setSDPUseMath(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_sdp_use_math expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setSDPUseMath(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
PyObject* THPModule_userEnabledMathSDP(PyObject* _unused, PyObject* noargs) {
  if (at::globalContext().userEnabledMathSDP())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}
PyObject* THPModule_setAllowFP16BF16ReductionMathSDP(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_math_sdp_allow_fp16_bf16_reduction expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setAllowFP16BF16ReductionMathSDP(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
PyObject* THPModule_allowFP16BF16ReductionMathSDP(
    PyObject* _unused,
    PyObject* noargs) {
  if (at::globalContext().allowFP16BF16ReductionMathSDP())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}
PyObject* THPModule_setSDPUseOverrideable(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_sdp_use_overrideable expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setSDPUseOverrideable(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
PyObject* THPModule_userEnabledOverrideableSDP(
    PyObject* _unused,
    PyObject* noargs) {
  if (at::globalContext().userEnabledOverrideableSDP())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}
PyObject* THPModule_setSDPUseCuDNN(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_sdp_use_cudnn expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setSDPUseCuDNN(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
PyObject* THPModule_userEnabledCuDNNSDP(PyObject* _unused, PyObject* noargs) {
  if (at::globalContext().userEnabledCuDNNSDP())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}
PyObject* THPModule_setUserEnabledCuDNN(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_enabled_cudnn expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setUserEnabledCuDNN(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_userEnabledCuDNN(PyObject* _unused, PyObject* noargs) {
  if (at::globalContext().userEnabledCuDNN())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}

PyObject* THPModule_setUserEnabledMkldnn(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_enabled_mkldnn expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setUserEnabledMkldnn(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_userEnabledMkldnn(PyObject* _unused, PyObject* noargs) {
  if (at::globalContext().userEnabledMkldnn())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}

PyObject* THPModule_setDeterministicCuDNN(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_deterministic_cudnn expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setDeterministicCuDNN(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_deterministicCuDNN(PyObject* _unused, PyObject* noargs) {
  if (at::globalContext().deterministicCuDNN())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}

PyObject* THPModule_setDeterministicMkldnn(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_deterministic_mkldnn expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setDeterministicMkldnn(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_deterministicMkldnn(PyObject* _unused, PyObject* noargs) {
  if (at::globalContext().deterministicMkldnn())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}

PyObject* THPModule_setDeterministicAlgorithms(
    PyObject* _unused,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static torch::PythonArgParser parser(
      {"_set_deterministic_algorithms(bool mode, *, bool warn_only=False)"});
  torch::ParsedArgs<2> parsed_args{};
  auto r = parser.parse(args, kwargs, parsed_args);
  bool mode = r.toBool(0);
  bool warn_only = r.toBool(1);
  at::globalContext().setDeterministicAlgorithms(mode, warn_only);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
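//
// A hedged sketch of the intended entry point: the documented API
// torch.use_deterministic_algorithms(mode, warn_only=...) is expected to
// funnel into the `_set_deterministic_algorithms` binding above:
//
//   torch.use_deterministic_algorithms(True, warn_only=True)
//   # nondeterministic ops now warn instead of raising
//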
PyObject* THPModule_deterministicAlgorithms(
    PyObject* _unused,
    PyObject* noargs) {
  if (at::globalContext().deterministicAlgorithms()) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}

PyObject* THPModule_deterministicAlgorithmsWarnOnly(
    PyObject* _unused,
    PyObject* noargs) {
  if (at::globalContext().deterministicAlgorithmsWarnOnly()) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}

PyObject* THPModule_setDeterministicFillUninitializedMemory(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg), "expected a bool, but got ", THPUtils_typename(arg));
  at::globalContext().setDeterministicFillUninitializedMemory(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_deterministicFillUninitializedMemory(
    PyObject* _unused,
    PyObject* noargs) {
  if (at::globalContext().deterministicFillUninitializedMemory())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}

PyObject* THPModule_setUserEnabledNNPACK(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_enabled_NNPACK expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setUserEnabledNNPACK(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_userEnabledNNPACK(PyObject* _unused, PyObject* noargs) {
  if (at::globalContext().userEnabledNNPACK())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}
PyObject* THPModule_setWarnAlways(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_warnAlways expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  c10::WarningUtils::set_warnAlways(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_warnAlways(PyObject* _unused, PyObject* noargs) {
  if (c10::WarningUtils::get_warnAlways()) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}
// Used only for testing C++ to Python warning translations.
PyObject* THPModule_warn(PyObject* _unused, PyObject* noargs) {
  HANDLE_TH_ERRORS
  TORCH_WARN("Test message for TORCH_WARN");
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

// Used only for testing C++ to Python warning translations.
PyObject* THPModule_warnDeprecation(PyObject* _unused, PyObject* noargs) {
  HANDLE_TH_ERRORS
  TORCH_WARN_DEPRECATION("Test message for TORCH_WARN_DEPRECATION");
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_setBenchmarkCuDNN(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_benchmark_cudnn expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setBenchmarkCuDNN(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_benchmarkCuDNN(PyObject* _unused, PyObject* noargs) {
  if (at::globalContext().benchmarkCuDNN()) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}

PyObject* THPModule_setAllowTF32CuBLAS(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_allow_tf32_cublas expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setAllowTF32CuBLAS(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_allowTF32CuBLAS(PyObject* _unused, PyObject* noargs) {
  if (at::globalContext().allowTF32CuBLAS()) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}

PyObject* THPModule_setAllowFP16ReductionCuBLAS(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_allow_fp16_reduction_cublas expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setAllowFP16ReductionCuBLAS(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_allowFP16ReductionCuBLAS(
    PyObject* _unused,
    PyObject* noargs) {
  if (at::globalContext().allowFP16ReductionCuBLAS()) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}

PyObject* THPModule_setAllowBF16ReductionCuBLAS(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_allow_bf16_reduction_cublas expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setAllowBF16ReductionCuBLAS(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_allowBF16ReductionCuBLAS(
    PyObject* _unused,
    PyObject* noargs) {
  if (at::globalContext().allowBF16ReductionCuBLAS()) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}

PyObject* THPModule_setAllowFP16ReductionCPU(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_allow_fp16_reduction_cpu expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setAllowFP16ReductionCPU(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_allowFP16ReductionCPU(PyObject* _unused, PyObject* noargs) {
  if (at::globalContext().allowFP16ReductionCPU()) {
    Py_RETURN_TRUE;
  }
  Py_RETURN_FALSE;
}
PyObject* THPModule_setFlushDenormal(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_flush_denormal expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  if (!at::globalContext().setFlushDenormal(arg == Py_True)) {
    Py_RETURN_FALSE;
  }
  Py_RETURN_TRUE;
  END_HANDLE_TH_ERRORS
}
PyObject* THPModule_getDefaultDtype(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  auto scalar_type = torch::tensors::get_default_scalar_type();
  return Py_NewRef(torch::getTHPDtype(scalar_type));
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_getDefaultDevice(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  return THPUtils_packString(c10::DeviceTypeName(
      dispatchKeyToDeviceType(torch::tensors::get_default_dispatch_key()),
      /*lower_case=*/true));
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_setQEngine(PyObject* /* unused */, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      THPUtils_checkLong(arg),
      "set_qengine expects an int, "
      "but got ",
      THPUtils_typename(arg));
  auto qengine = THPUtils_unpackLong(arg);
  at::globalContext().setQEngine(static_cast<at::QEngine>(qengine));
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_qEngine(PyObject* _unused, PyObject* noargs) {
  return THPUtils_packInt64(
      static_cast<int64_t>(at::globalContext().qEngine()));
}

PyObject* THPModule_supportedQEngines(PyObject* _unused, PyObject* noargs) {
  auto qengines = at::globalContext().supportedQEngines();
  auto list =
      THPObjectPtr(PyList_New(static_cast<Py_ssize_t>(qengines.size())));
  if (!list)
    return nullptr;
  for (const auto i : c10::irange(qengines.size())) {
    PyObject* i64 = THPUtils_packInt64(static_cast<int64_t>(qengines[i]));
    if (!i64)
      return nullptr;
    PyList_SET_ITEM(list.get(), i, i64);
  }
  return list.release();
}

PyObject* THPModule_isEnabledXNNPACK(PyObject* _unused, PyObject* noargs) {
  if (at::globalContext().isXNNPACKAvailable())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}

PyObject* THPModule_setCheckSparseTensorInvariants(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "set_check_sparse_tensor_invariants expects a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setCheckSparseTensorInvariants(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_checkSparseTensorInvariants(
    PyObject* _unused,
    PyObject* noargs) {
  if (at::globalContext().checkSparseTensorInvariants())
    Py_RETURN_TRUE;
  else
    Py_RETURN_FALSE;
}

PyObject* THPModule_willEngineExecuteNode(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  bool isTHPFunction = THPFunction_Check(arg);
  bool isTHPCppFunction = torch::autograd::THPCppFunction_Check(arg);
  TORCH_CHECK(
      isTHPFunction || isTHPCppFunction,
      "_will_engine_execute_node expects a grad_fn, "
      "but got ",
      THPUtils_typename(arg));
  const auto exec_info = torch::autograd::get_current_graph_task_exec_info();
  TORCH_CHECK(
      exec_info,
      "_will_engine_execute_node should only be called during the backward pass");
  torch::autograd::Node* node = nullptr;
  std::shared_ptr<torch::autograd::Node> node_sp;
  if (isTHPFunction) {
    node_sp = ((THPFunction*)arg)->cdata.lock();
    node = node_sp.get();
  } else {
    node = ((torch::autograd::THPCppFunction*)arg)->cdata.get();
  }
  const auto nodes_in_graph =
      torch::autograd::get_current_graph_task_nodes_in_graph();
  bool ret = nodes_in_graph->find(node) != nodes_in_graph->end();
  if (ret && !exec_info->empty()) {
    auto it = exec_info->find(node);
    if (it == exec_info->end() || !it->second.should_execute()) {
      ret = false;
    } else {
      TORCH_CHECK(
          !(node->topological_nr() == 0 && it->second.captures_),
          "A leaf node was passed to _will_engine_execute_node but we are "
          "currently running autograd.grad(). This is currently not supported.");
    }
  }
  if (ret) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}
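//
// A minimal sketch of calling this during a backward pass (hedged: the hook
// placement is illustrative; outside a backward pass the TORCH_CHECK above
// fires):
//
//   x = torch.randn(3, requires_grad=True)
//   y = (x * 2).sum()
//   y.grad_fn.register_prehook(
//       lambda *grads: print(torch._C._will_engine_execute_node(y.grad_fn)))
//   y.backward()
//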
PyObject* THPModule_getCurrentGraphTaskExecutionOrder(
    PyObject* _unused,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  std::vector<torch::autograd::Node*> nodes =
      torch::autograd::get_current_graph_task_execution_order();
  TORCH_CHECK(
      !nodes.empty(),
      "_current_graph_task_execution_order should only be called during the backward pass");
  auto list = THPObjectPtr(PyList_New(static_cast<Py_ssize_t>(nodes.size())));
  if (!list)
    return nullptr;
  for (const auto i : c10::irange(nodes.size())) {
    // This node is guaranteed to be alive since the backward is still running
    PyObject* pyobj_node =
        torch::autograd::functionToPyObject(nodes[i]->getptr());
    PyList_SET_ITEM(list.get(), i, pyobj_node);
  }
  return list.release();
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_getCurrentGraphTaskId(PyObject* _unused, PyObject* noargs) {
  HANDLE_TH_ERRORS
  return THPUtils_packInt64(torch::autograd::get_current_graph_task_id());
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_getCurrentNode(PyObject* _unused, PyObject* noargs) {
  HANDLE_TH_ERRORS
  return torch::autograd::functionToPyObject(
      torch::autograd::get_current_node());
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_setDefaultMobileCPUAllocator(
    PyObject* _unused,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  at::globalContext().setDefaultMobileCPUAllocator();
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_unsetDefaultMobileCPUAllocator(
    PyObject* _unused,
    PyObject* noargs) {
  HANDLE_TH_ERRORS
  at::globalContext().unsetDefaultMobileCPUAllocator();
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_vmapmode_increment_nesting(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  return THPUtils_packInt64(at::impl::VmapMode::increment_nesting());
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_vmapmode_decrement_nesting(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  return THPUtils_packInt64(at::impl::VmapMode::decrement_nesting());
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_set_display_vmap_fallback_warnings_mode(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      PyBool_Check(arg),
      "enabled must be a bool, "
      "but got ",
      THPUtils_typename(arg));
  at::globalContext().setDisplayVmapFallbackWarnings(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* THPModule_are_vmap_fallback_warnings_enabled(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  if (at::globalContext().areVmapFallbackWarningsEnabled()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}
static PyMethodDef TorchMethods[] = { // NOLINT
    {"_initExtension", THPModule_initExtension, METH_O, nullptr},
    {"_autograd_init", THPAutograd_initExtension, METH_NOARGS, nullptr},
    {"_add_docstr", THPModule_addDocStr, METH_VARARGS, nullptr},
    {"_swap_tensor_impl", THPModule_swap_tensor_impl, METH_VARARGS, nullptr},
    {"_init_names", THPModule_initNames, METH_O, nullptr},
    {"_has_distributed", THPModule_hasDistributed, METH_NOARGS, nullptr},
    {"_set_default_tensor_type",
     THPModule_setDefaultTensorType,
     METH_O,
     nullptr},
    {"_set_default_dtype", THPModule_setDefaultDtype, METH_O, nullptr},
    {"_infer_size", THPModule_inferSize, METH_VARARGS, nullptr},
    {"_abort", THPModule_abort, METH_NOARGS, nullptr},
    {"_crash_if_csrc_asan", THPModule_crashIfCsrcASAN, METH_O, nullptr},
    {"_crash_if_csrc_ubsan", THPModule_crashIfCsrcUBSAN, METH_O, nullptr},
    {"_crash_if_vptr_ubsan", THPModule_crashIfvptrUBSAN, METH_NOARGS, nullptr},
    {"_crash_if_aten_asan", THPModule_crashIfATenASAN, METH_O, nullptr},
    {"_crash_if_debug_asserts_fail",
     THPModule_crashIfDebugAssertsFail,
     METH_O,
     nullptr},
    {"_show_config", THPModule_showConfig, METH_NOARGS, nullptr},
    {"_cxx_flags", THPModule_cxxFlags, METH_NOARGS, nullptr},
    {"_parallel_info", THPModule_parallelInfo, METH_NOARGS, nullptr},
    {"_get_cpu_capability", THPModule_getCpuCapability, METH_NOARGS, nullptr},
    {"_set_backcompat_broadcast_warn",
     THPModule_setBackcompatBroadcastWarn,
     METH_O,
     nullptr},
    {"_get_backcompat_broadcast_warn",
     THPModule_getBackcompatBroadcastWarn,
     METH_NOARGS,
     nullptr},
    {"_set_backcompat_keepdim_warn",
     THPModule_setBackcompatKeepdimWarn,
     METH_O,
     nullptr},
    {"_get_backcompat_keepdim_warn",
     THPModule_getBackcompatKeepdimWarn,
     METH_NOARGS,
     nullptr},
    {"get_num_threads", THPModule_getNumThreads, METH_NOARGS, nullptr},
    {"set_num_threads", THPModule_setNumThreads, METH_O, nullptr},
    {"get_num_interop_threads",
     THPModule_getNumInteropThreads,
     METH_NOARGS,
     nullptr},
    {"set_num_interop_threads",
     THPModule_setNumInteropThreads,
     METH_O,
     nullptr},
    {"_get_flash_sdp_enabled",
     THPModule_userEnabledFlashSDP,
     METH_NOARGS,
     nullptr},
    {"_set_sdp_use_flash", THPModule_setSDPUseFlash, METH_O, nullptr},
    {"_get_mem_efficient_sdp_enabled",
     userEnabledMemEfficientSDP,
     METH_NOARGS,
     nullptr},
    {"_set_sdp_use_mem_efficient",
     THPModule_setSDPUseMemEfficient,
     METH_O,
     nullptr},
    {"_get_math_sdp_enabled",
     THPModule_userEnabledMathSDP,
     METH_NOARGS,
     nullptr},
    {"_set_sdp_use_math", THPModule_setSDPUseMath, METH_O, nullptr},
    {"_get_math_sdp_allow_fp16_bf16_reduction",
     THPModule_allowFP16BF16ReductionMathSDP,
     METH_NOARGS,
     nullptr},
    {"_set_math_sdp_allow_fp16_bf16_reduction",
     THPModule_setAllowFP16BF16ReductionMathSDP,
     METH_O,
     nullptr},
    {"_get_overrideable_sdp_enabled",
     THPModule_userEnabledOverrideableSDP,
     METH_NOARGS,
     nullptr},
    {"_set_sdp_use_overrideable",
     THPModule_setSDPUseOverrideable,
     METH_O,
     nullptr},
    {"_get_cudnn_sdp_enabled",
     THPModule_userEnabledCuDNNSDP,
     METH_NOARGS,
     nullptr},
    {"_set_sdp_use_cudnn", THPModule_setSDPUseCuDNN, METH_O, nullptr},
    {"_get_cudnn_enabled", THPModule_userEnabledCuDNN, METH_NOARGS, nullptr},
    {"_set_cudnn_enabled", THPModule_setUserEnabledCuDNN, METH_O, nullptr},
    {"_get_mkldnn_enabled", THPModule_userEnabledMkldnn, METH_NOARGS, nullptr},
    {"_set_mkldnn_enabled", THPModule_setUserEnabledMkldnn, METH_O, nullptr},
    {"_get_cudnn_allow_tf32", THPModule_allowTF32CuDNN, METH_NOARGS, nullptr},
    {"_set_cudnn_allow_tf32", THPModule_setAllowTF32CuDNN, METH_O, nullptr},
    {"_get_cudnn_benchmark", THPModule_benchmarkCuDNN, METH_NOARGS, nullptr},
    {"_set_cudnn_benchmark", THPModule_setBenchmarkCuDNN, METH_O, nullptr},
    {"_get_cudnn_deterministic",
     THPModule_deterministicCuDNN,
     METH_NOARGS,
     nullptr},
    {"_set_cudnn_deterministic",
     THPModule_setDeterministicCuDNN,
     METH_O,
     nullptr},
    {"_get_mkldnn_deterministic",
     THPModule_deterministicMkldnn,
     METH_NOARGS,
     nullptr},
    {"_set_mkldnn_deterministic",
     THPModule_setDeterministicMkldnn,
     METH_O,
     nullptr},
    {"_get_deterministic_algorithms",
     THPModule_deterministicAlgorithms,
     METH_NOARGS,
     nullptr},
    {"_get_deterministic_algorithms_warn_only",
     THPModule_deterministicAlgorithmsWarnOnly,
     METH_NOARGS,
     nullptr},
    {"_set_deterministic_algorithms",
     castPyCFunctionWithKeywords(THPModule_setDeterministicAlgorithms),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_get_deterministic_fill_uninitialized_memory",
     THPModule_deterministicFillUninitializedMemory,
     METH_NOARGS,
     nullptr},
    {"_set_deterministic_fill_uninitialized_memory",
     THPModule_setDeterministicFillUninitializedMemory,
     METH_O,
     nullptr},
    {"_get_nnpack_enabled", THPModule_userEnabledNNPACK, METH_NOARGS, nullptr},
    {"_set_nnpack_enabled", THPModule_setUserEnabledNNPACK, METH_O, nullptr},
    {"_get_warnAlways", THPModule_warnAlways, METH_NOARGS, nullptr},
    {"_set_warnAlways", THPModule_setWarnAlways, METH_O, nullptr},
    {"_warn", THPModule_warn, METH_NOARGS, nullptr},
    {"_warn_deprecation", THPModule_warnDeprecation, METH_NOARGS, nullptr},
    {"_get_cublas_allow_tf32", THPModule_allowTF32CuBLAS, METH_NOARGS, nullptr},
    {"_set_cublas_allow_tf32", THPModule_setAllowTF32CuBLAS, METH_O, nullptr},
    {"_get_float32_matmul_precision",
     THPModule_float32MatmulPrecision,
     METH_NOARGS,
     nullptr},
    {"_set_float32_matmul_precision",
     THPModule_setFloat32MatmulPrecision,
     METH_O,
     nullptr},
    {"_get_cublas_allow_fp16_reduced_precision_reduction",
     THPModule_allowFP16ReductionCuBLAS,
     METH_NOARGS,
     nullptr},
    {"_set_cublas_allow_fp16_reduced_precision_reduction",
     THPModule_setAllowFP16ReductionCuBLAS,
     METH_O,
     nullptr},
    {"_get_cublas_allow_bf16_reduced_precision_reduction",
     THPModule_allowBF16ReductionCuBLAS,
     METH_NOARGS,
     nullptr},
    {"_set_cublas_allow_bf16_reduced_precision_reduction",
     THPModule_setAllowBF16ReductionCuBLAS,
     METH_O,
     nullptr},
    {"_get_cpu_allow_fp16_reduced_precision_reduction",
     THPModule_allowFP16ReductionCPU,
     METH_NOARGS,
     nullptr},
    {"_set_cpu_allow_fp16_reduced_precision_reduction",
     THPModule_setAllowFP16ReductionCPU,
     METH_O,
     nullptr},
    {"_vmapmode_increment_nesting",
     THPModule_vmapmode_increment_nesting,
     METH_NOARGS,
     nullptr},
    {"_vmapmode_decrement_nesting",
     THPModule_vmapmode_decrement_nesting,
     METH_NOARGS,
     nullptr},
    {"_debug_only_display_vmap_fallback_warnings",
     THPModule_set_display_vmap_fallback_warnings_mode,
     METH_O,
     nullptr},
    {"_debug_only_are_vmap_fallback_warnings_enabled",
     THPModule_are_vmap_fallback_warnings_enabled,
     METH_NOARGS,
     nullptr},
    {"_to_dlpack", THPModule_toDLPack, METH_O, nullptr},
    {"_from_dlpack", THPModule_fromDLPack, METH_O, nullptr},
    {"_get_cpp_backtrace", THModule_getCppBacktrace, METH_VARARGS, nullptr},
    {"_rename_privateuse1_backend",
     THModule_rename_privateuse1_backend,
     METH_O,
     nullptr},
    {"_get_privateuse1_backend_name",
     THModule_get_privateuse1_backend_name,
     METH_NOARGS,
     nullptr},
    {"set_flush_denormal", THPModule_setFlushDenormal, METH_O, nullptr},
    {"get_default_dtype", THPModule_getDefaultDtype, METH_NOARGS, nullptr},
    {"_get_default_device", THPModule_getDefaultDevice, METH_NOARGS, nullptr},
1520     {"_get_qengine", THPModule_qEngine, METH_NOARGS, nullptr},
1521     {"_set_qengine", THPModule_setQEngine, METH_O, nullptr},
1522     {"_supported_qengines", THPModule_supportedQEngines, METH_NOARGS, nullptr},
1523     {"_is_xnnpack_enabled", THPModule_isEnabledXNNPACK, METH_NOARGS, nullptr},
1524     {"_set_check_sparse_tensor_invariants",
1525      THPModule_setCheckSparseTensorInvariants,
1526      METH_O,
1527      nullptr},
1528     {"_check_sparse_tensor_invariants",
1529      THPModule_checkSparseTensorInvariants,
1530      METH_NOARGS,
1531      nullptr},
1532     {"_will_engine_execute_node",
1533      THPModule_willEngineExecuteNode,
1534      METH_O,
1535      nullptr},
1536     {"_current_graph_task_execution_order",
1537      THPModule_getCurrentGraphTaskExecutionOrder,
1538      METH_NOARGS,
1539      nullptr},
1540     {"_current_graph_task_id",
1541      THPModule_getCurrentGraphTaskId,
1542      METH_NOARGS,
1543      nullptr},
1544     {"_current_autograd_node", THPModule_getCurrentNode, METH_NOARGS, nullptr},
1545     {"_set_default_mobile_cpu_allocator",
1546      THPModule_setDefaultMobileCPUAllocator,
1547      METH_NOARGS,
1548      nullptr},
1549     {"_unset_default_mobile_cpu_allocator",
1550      THPModule_unsetDefaultMobileCPUAllocator,
1551      METH_NOARGS,
1552      nullptr},
1553     {"_is_torch_function_enabled",
1554      THPModule_isEnabledTorchFunction,
1555      METH_NOARGS,
1556      nullptr},
1557     {"_is_torch_function_all_disabled",
1558      THPModule_isAllDisabledTorchFunction,
1559      METH_NOARGS,
1560      nullptr},
1561     {"_disabled_torch_function_impl",
1562      THPModule_disable_torch_function,
1563      METH_VARARGS,
1564      nullptr},
1565     {"_disabled_torch_dispatch_impl",
1566      THPModule_disable_torch_dispatch,
1567      METH_VARARGS,
1568      nullptr},
1569     {"_has_torch_function", THPModule_has_torch_function, METH_O, nullptr},
1570     {"_has_torch_function_unary",
1571      THPModule_has_torch_function_unary,
1572      METH_O,
1573      nullptr},
1574     {"_has_torch_function_variadic",
1575      (PyCFunction)(void (*)())THPModule_has_torch_function_variadic,
1576      METH_FASTCALL,
1577      nullptr},
1578     {nullptr, nullptr, 0, nullptr}};
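
// The table above follows the CPython convention: each PyMethodDef entry
// names the Python-visible function, its C implementation, a calling
// convention flag (METH_NOARGS, METH_O, METH_VARARGS | METH_KEYWORDS, or
// METH_FASTCALL), and an optional docstring, with a
// {nullptr, nullptr, 0, nullptr} sentinel terminating the array. A minimal
// sketch of one getter/setter pair in the same style (hypothetical names,
// not part of this file):
//
//   static bool g_example_flag = false;
//   static PyObject* Example_getFlag(PyObject* self, PyObject* unused) {
//     return PyBool_FromLong(g_example_flag);   // METH_NOARGS: no arguments
//   }
//   static PyObject* Example_setFlag(PyObject* self, PyObject* arg) {
//     g_example_flag = PyObject_IsTrue(arg);    // METH_O: exactly one arg
//     Py_RETURN_NONE;
//   }
//   static PyMethodDef ExampleMethods[] = {
//       {"_get_example_flag", Example_getFlag, METH_NOARGS, nullptr},
//       {"_set_example_flag", Example_setFlag, METH_O, nullptr},
//       {nullptr, nullptr, 0, nullptr}};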
1579 
1580 void THCPStream_init(PyObject* module);
1581 void THCPEvent_init(PyObject* module);
1582 void THCPGraph_init(PyObject* module);
1583 void THCPMemPool_init(PyObject* module);
1584 
1585 #ifdef USE_CUDA
1586 PyMethodDef* THCPModule_methods();
1587 namespace torch::cuda {
1588 void initModule(PyObject* module);
1589 } // namespace torch::cuda
1590 #endif
1591 
1592 #ifdef USE_XPU
1593 PyMethodDef* THXPModule_methods();
1594 void THXPStream_init(PyObject* module);
1595 void THXPEvent_init(PyObject* module);
1596 namespace torch::xpu {
1597 void initModule(PyObject* module);
1598 } // namespace torch::xpu
1599 #endif
1600 
1601 #ifdef USE_ITT
1602 namespace torch::profiler {
1603 void initIttBindings(PyObject* module);
1604 } // namespace torch::profiler
1605 #endif
1606 
1607 static std::vector<PyMethodDef> methods;
1608 
1609 // In Python we can't use the C10_LOG_API_USAGE_ONCE trick, so deduplicate here.
1610 // Guaranteed to be invoked from Python under the GIL; no locking on the set needed.
1611 static void LogAPIUsageOnceFromPython(const std::string& event) {
1612   static std::unordered_set<std::string> seen;
1613   if (!seen.count(event)) {
1614     seen.insert(event);
1615     c10::LogAPIUsage(event);
1616   }
1617 }
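
// Bound below as torch._C._log_api_usage_once. The static `seen` set needs
// no mutex because every call arrives from Python while holding the GIL. If
// this were ever called from arbitrary C++ threads, a guarded variant along
// these lines would be needed (illustrative sketch only; requires <mutex>):
//
//   static void LogAPIUsageOnceThreadSafe(const std::string& event) {
//     static std::mutex mu;
//     static std::unordered_set<std::string> seen;
//     std::lock_guard<std::mutex> lock(mu);
//     if (seen.insert(event).second) {  // true only on first insertion
//       c10::LogAPIUsage(event);
//     }
//   }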
1618 
1619 static void LogAPIUsageMetadataFromPython(
1620     const std::string& event,
1621     const std::map<std::string, std::string>& metadata_map) {
1622   c10::LogAPIUsageMetadata(event, metadata_map);
1623 }
1624 
1625 // Weak reference to tensor, used to test a tensor isn't leaked
1626 class WeakTensorRef {
1627   c10::weak_intrusive_ptr<c10::TensorImpl> weakref_;
1628 
1629  public:
1630   WeakTensorRef(const at::Tensor& t) : weakref_(t.getIntrusivePtr()) {}
1631 
1632   bool expired() {
1633     return weakref_.expired();
1634   }
1635 };
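
// Intended use from a test, sketched with illustrative names: take a weak
// reference while the tensor is alive, drop the last strong reference, then
// assert expiry to prove the TensorImpl was not leaked.
//
//   void checkTensorNotLeaked() {
//     std::optional<at::Tensor> t = at::ones({2, 2});
//     WeakTensorRef ref(*t);             // holds only a weak reference
//     TORCH_INTERNAL_ASSERT(!ref.expired());
//     t.reset();                         // drop the last strong reference
//     TORCH_INTERNAL_ASSERT(ref.expired());
//   }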
1636 
1637 extern "C" C10_EXPORT PyObject* initModule();
1638 // separate declaration and definition to work around MSVC error C2491
1639 PyObject* initModule() {
1640   HANDLE_TH_ERRORS
1641 
1642   c10::initLogging();
1643   c10::set_terminate_handler();
1644   at::internal::lazy_init_num_threads();
1645 
1646   C10_LOG_API_USAGE_ONCE("torch.python.import");
1647 
1648 #define ASSERT_TRUE(cmd) \
1649   if (!(cmd))            \
1650   return nullptr
1651 
1652   THPUtils_addPyMethodDefs(methods, TorchMethods);
1653   THPUtils_addPyMethodDefs(methods, DataLoaderMethods);
1654   THPUtils_addPyMethodDefs(methods, torch::autograd::python_functions());
1655   THPUtils_addPyMethodDefs(methods, torch::multiprocessing::python_functions());
1656   THPUtils_addPyMethodDefs(methods, torch::mps::python_functions());
1657 #ifdef USE_CUDA
1658   THPUtils_addPyMethodDefs(methods, THCPModule_methods());
1659 #endif
1660 #ifdef USE_XPU
1661   THPUtils_addPyMethodDefs(methods, THXPModule_methods());
1662 #endif
1663 #if defined(USE_DISTRIBUTED) && defined(USE_C10D)
1664   THPUtils_addPyMethodDefs(
1665       methods, torch::distributed::c10d::python_functions());
1666 #ifndef _WIN32
1667   THPUtils_addPyMethodDefs(
1668       methods, torch::distributed::rpc::python_functions());
1669   THPUtils_addPyMethodDefs(
1670       methods, torch::distributed::autograd::python_functions());
1671   THPUtils_addPyMethodDefs(
1672       methods, torch::distributed::rpc::testing::python_functions());
1673 #endif
1674 #endif
1675 
1676   static struct PyModuleDef torchmodule = {
1677       PyModuleDef_HEAD_INIT, "torch._C", nullptr, -1, methods.data()};
1678   module = PyModule_Create(&torchmodule);
1679   ASSERT_TRUE(module);
1680   ASSERT_TRUE(THPGenerator_init(module));
1681   ASSERT_TRUE(THPException_init(module));
1682   THPSize_init(module);
1683   THPDtype_init(module);
1684   THPDTypeInfo_init(module);
1685   THPLayout_init(module);
1686   THPMemoryFormat_init(module);
1687   THPQScheme_init(module);
1688   THPDevice_init(module);
1689   THPStream_init(module);
1690   THPEvent_init(module);
1691   NodeBase_init(module);
1692   NodeIter_init(module);
1693   ASSERT_TRUE(THPVariable_initModule(module));
1694   ASSERT_TRUE(THPFunction_initModule(module));
1695   ASSERT_TRUE(THPEngine_initModule(module));
1696   // NOTE: We need to be able to access OperatorExportTypes from ONNX for use in
1697   // the export side of JIT, so this ONNX init needs to appear before the JIT
1698   // init.
1699   torch::onnx::initONNXBindings(module);
1700   torch::autograd::initEnumTag(module);
1701   torch::jit::initJITBindings(module);
1702   torch::monitor::initMonitorBindings(module);
1703   torch::impl::dispatch::initDispatchBindings(module);
1704   torch::dynamo::initDynamoBindings(module);
1705   torch::functorch::impl::initFuncTorchBindings(module);
1706   torch::throughput_benchmark::initThroughputBenchmarkBindings(module);
1707   torch::autograd::initReturnTypes(module);
1708   torch::autograd::initNNFunctions(module);
1709   torch::autograd::initFFTFunctions(module);
1710   torch::autograd::initLinalgFunctions(module);
1711   torch::autograd::initNestedFunctions(module);
1712   torch::autograd::initSparseFunctions(module);
1713   torch::autograd::initSpecialFunctions(module);
1714   torch::autograd::init_legacy_variable(module);
1715   torch::profiler::initPythonBindings(module);
1716   torch::python::init_bindings(module);
1717   torch::lazy::initLazyBindings(module);
1718   torch::inductor::initAOTIRunnerBindings(module);
1719 #ifdef USE_ITT
1720   torch::profiler::initIttBindings(module);
1721 #endif
1722 #ifdef USE_CUDA
1723   torch::cuda::initModule(module);
1724 #endif
1725 #ifdef USE_XPU
1726   torch::xpu::initModule(module);
1727 #endif
1728   torch::mtia::initModule(module);
1729   torch::cpu::initModule(module);
1730   torch::instruction_counter::initModule(module);
1731   torch::initVerboseBindings(module);
1732   ASSERT_TRUE(THPStorage_init(module));
1733 
1734 #ifdef USE_CUDA
1735   // This will only initialize base classes and attach them to the library
1736   // namespace. They won't be ready for real usage until the cuda module is
1737   // imported, which completes the process (but it defines Python classes
1738   // before calling back into C, so these lines have to execute first).
1739   THCPStream_init(module);
1740   THCPEvent_init(module);
1741   THCPGraph_init(module);
1742   THCPMemPool_init(module);
1743 #endif
1744 
1745 #ifdef USE_XPU
1746   THXPStream_init(module);
1747   THXPEvent_init(module);
1748 #endif
1749 
1750   auto set_module_attr =
1751       [&](const char* name, PyObject* v, bool incref = true) {
1752         // PyModule_AddObject steals reference
1753         if (incref) {
1754           Py_INCREF(v);
1755         }
1756 
1757         int ret = PyModule_AddObject(module, name, v);
1758         if (ret != 0) {
1759           Py_DECREF(v);
1760         }
1761 
1762         return ret == 0;
1763       };
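
// PyModule_AddObject steals the reference only on success, which is why the
// helper above increfs first and decrefs when the call fails. The same
// contract spelled out without the lambda (illustrative sketch):
//
//   bool addModuleAttr(PyObject* mod, const char* name, PyObject* v) {
//     Py_INCREF(v);                       // give AddObject its own reference
//     if (PyModule_AddObject(mod, name, v) != 0) {
//       Py_DECREF(v);                     // failure: AddObject did not steal
//       return false;
//     }
//     return true;                        // reference now owned by `mod`
//   }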
1764 
1765 #if defined(USE_CUDNN) || defined(USE_ROCM)
1766   PyObject* has_cudnn = Py_True;
1767 #else
1768   PyObject* has_cudnn = Py_False;
1769 #endif
1770   ASSERT_TRUE(set_module_attr("_has_cudnn", has_cudnn));
1771 
1772 #if defined(USE_CUSPARSELT)
1773   PyObject* has_cusparselt = Py_True;
1774 #else
1775   PyObject* has_cusparselt = Py_False;
1776 #endif
1777   ASSERT_TRUE(set_module_attr("_has_cusparselt", has_cusparselt));
1778 
1779 #if AT_MKL_ENABLED() || AT_POCKETFFT_ENABLED()
1780   PyObject* has_spectral = Py_True;
1781 #else
1782   PyObject* has_spectral = Py_False;
1783 #endif
1784   ASSERT_TRUE(set_module_attr("has_spectral", has_spectral));
1785 
1786   // force ATen to initialize because it handles
1787   // setting up TH Errors so that they throw C++ exceptions
1788   at::init();
1789 
1790   // Automatically translate errors thrown from pybind11 functions
1791   py::register_exception_translator([](std::exception_ptr e) { // NOLINT
1792     try {
1793       if (e) {
1794         std::rethrow_exception(e);
1795       }
1796     }
1797     CATCH_TH_ERRORS()
1798   });
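
// pybind11 runs registered translators newest-first until one sets a Python
// error; CATCH_TH_ERRORS() expands to the catch clauses that do so here. A
// hand-written translator for a single exception type would look like this
// (sketch; MyError is a stand-in, not a type in this file):
//
//   py::register_exception_translator([](std::exception_ptr p) {
//     try {
//       if (p) std::rethrow_exception(p);
//     } catch (const MyError& e) {
//       PyErr_SetString(PyExc_RuntimeError, e.what());
//     }
//   });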
1799 
1800   auto py_module = py::reinterpret_borrow<py::module>(module);
1801   py_module.def("_demangle", &c10::demangle);
1802   py_module.def("_log_api_usage_once", &LogAPIUsageOnceFromPython);
1803   py_module.def("_log_api_usage_metadata", &LogAPIUsageMetadataFromPython);
1804 
1805   py_module.def("vitals_enabled", &at::vitals::torchVitalEnabled);
1806   py_module.def(
1807       "set_vital",
1808       [](const std::string& vital,
1809          const std::string& attr,
1810          const std::string& value) {
1811         return at::vitals::VitalsAPI.setVital(vital, attr, value);
1812       });
1813   py_module.def(
1814       "read_vitals", []() { return at::vitals::VitalsAPI.readVitals(); });
1815 
1816   py_module.def(
1817       "init_num_threads",
1818       torch::wrap_pybind_function(at::init_num_threads),
1819       R"(
1820 init_num_threads()
1821 
1822 Initializes the number of parallel threads used on the current thread.
1823 
1824 Call this whenever a new thread is created in order to propagate values from
1825 :func:`torch.set_num_threads` onto the new thread.
1826 )");
1827 
1828   py_module.def("_set_cached_tensors_enabled", [](bool enabled) {
1829     at::caching::set_cached_tensors_enabled(enabled);
1830   });
1831 
1832   py_module.def("_add_cached_tensor", [](const at::Tensor& t) {
1833     at::caching::add_cached_tensor(t);
1834   });
1835 
1836   py_module.def("_remove_cached_tensor", [](const at::Tensor& t) {
1837     at::caching::remove_cached_tensor(t);
1838   });
1839 
1840   py_module.def("_is_cached_tensor", [](const at::Tensor& t) {
1841     return at::caching::is_cached_tensor(t);
1842   });
1843 
1844   ASSERT_TRUE(
1845       set_module_attr("has_openmp", at::hasOpenMP() ? Py_True : Py_False));
1846   ASSERT_TRUE(set_module_attr("has_mkl", at::hasMKL() ? Py_True : Py_False));
1847   ASSERT_TRUE(
1848       set_module_attr("has_lapack", at::hasLAPACK() ? Py_True : Py_False));
1849 
1850   py_module.def("_valgrind_supported_platform", []() {
1851 #if defined(USE_VALGRIND)
1852     return true;
1853 #else
1854     return false;
1855 #endif
1856   });
1857 
1858   py_module.def("_valgrind_toggle", []() {
1859 #if defined(USE_VALGRIND)
1860     CALLGRIND_TOGGLE_COLLECT;
1861 #else
1862     TORCH_CHECK(false, "Valgrind is not supported.");
1863 #endif
1864   });
1865 
1866   py_module.def("_valgrind_toggle_and_dump_stats", []() {
1867 #if defined(USE_VALGRIND)
1868     // NB: If we don't toggle collect around dump stats, callgrind_annotate
1869     //     won't process the results correctly. Specifically,
1870     //     `callgrind_annotate --inclusive=no` will be almost completely empty.
1871     CALLGRIND_TOGGLE_COLLECT;
1872     CALLGRIND_DUMP_STATS;
1873 #else
1874     TORCH_CHECK(false, "Valgrind is not supported.");
1875 #endif
1876   });
1877 
1878   py::class_<WeakTensorRef>(py_module, "_WeakTensorRef")
1879       .def(py::init([](py::object tensor) {
1880         return WeakTensorRef(THPVariable_Unpack(tensor.ptr()));
1881       }))
1882       .def("expired", &WeakTensorRef::expired);
1883 
1884   py::enum_<at::native::ConvBackend>(py_module, "_ConvBackend")
1885       .value("CudaDepthwise2d", at::native::ConvBackend::CudaDepthwise2d)
1886       .value("CudaDepthwise3d", at::native::ConvBackend::CudaDepthwise3d)
1887       .value("Cudnn", at::native::ConvBackend::Cudnn)
1888       .value("CudnnTranspose", at::native::ConvBackend::CudnnTranspose)
1889       .value("Empty", at::native::ConvBackend::Empty)
1890       .value("Miopen", at::native::ConvBackend::Miopen)
1891       .value("MiopenDepthwise", at::native::ConvBackend::MiopenDepthwise)
1892       .value("MiopenTranspose", at::native::ConvBackend::MiopenTranspose)
1893       .value("Mkldnn", at::native::ConvBackend::Mkldnn)
1894       .value("MkldnnEmpty", at::native::ConvBackend::MkldnnEmpty)
1895       .value("NnpackSpatial", at::native::ConvBackend::NnpackSpatial)
1896       .value("Overrideable", at::native::ConvBackend::Overrideable)
1897       .value("Slow2d", at::native::ConvBackend::Slow2d)
1898       .value("Slow3d", at::native::ConvBackend::Slow3d)
1899       .value("SlowDilated2d", at::native::ConvBackend::SlowDilated2d)
1900       .value("SlowDilated3d", at::native::ConvBackend::SlowDilated3d)
1901       .value("SlowTranspose2d", at::native::ConvBackend::SlowTranspose2d)
1902       .value("SlowTranspose3d", at::native::ConvBackend::SlowTranspose3d)
1903       .value(
1904           "Winograd3x3Depthwise", at::native::ConvBackend::Winograd3x3Depthwise)
1905       .value("Xnnpack2d", at::native::ConvBackend::Xnnpack2d)
1906       .value("Mps", at::native::ConvBackend::Mps)
1907       .value("MpsTranspose,", at::native::ConvBackend::MpsTranspose);
1908 
1909   py_module.def(
1910       "_select_conv_backend",
1911       [](const at::Tensor& input,
1912          const at::Tensor& weight,
1913          const std::optional<at::Tensor>& bias_opt,
1914          at::SymIntArrayRef stride_,
1915          at::SymIntArrayRef padding_,
1916          at::SymIntArrayRef dilation_,
1917          bool transposed_,
1918          at::SymIntArrayRef output_padding_,
1919          c10::SymInt groups_) {
1920         return at::native::select_conv_backend(
1921             input,
1922             weight,
1923             bias_opt,
1924             stride_,
1925             padding_,
1926             dilation_,
1927             transposed_,
1928             output_padding_,
1929             std::move(groups_),
1930             std::nullopt);
1931       },
1932       py::arg("input"),
1933       py::arg("weight"),
1934       py::arg("bias"),
1935       py::arg("stride"),
1936       py::arg("padding"),
1937       py::arg("dilation"),
1938       py::arg("transposed"),
1939       py::arg("output_padding"),
1940       py::arg("groups"));
1941 
1942   // overload for bias_sizes_opt/backward TODO: figure out default value
1943   py_module.def(
1944       "_select_conv_backend",
1945       [](const at::Tensor& input,
1946          const at::Tensor& weight,
1947          const std::optional<at::Tensor>& bias,
1948          at::SymIntArrayRef stride_,
1949          at::SymIntArrayRef padding_,
1950          at::SymIntArrayRef dilation_,
1951          bool transposed_,
1952          at::SymIntArrayRef output_padding_,
1953          c10::SymInt groups_,
1954          std::optional<std::vector<c10::SymInt>> bias_sizes_opt) {
1955         c10::OptionalArrayRef<c10::SymInt> ref = std::nullopt;
1956         if (bias_sizes_opt) {
1957           ref = (*bias_sizes_opt);
1958         }
1959         return at::native::select_conv_backend(
1960             input,
1961             weight,
1962             bias,
1963             stride_,
1964             padding_,
1965             dilation_,
1966             transposed_,
1967             output_padding_,
1968             std::move(groups_),
1969             ref);
1970       },
1971       py::arg("input"),
1972       py::arg("weight"),
1973       py::arg("bias"),
1974       py::arg("stride"),
1975       py::arg("padding"),
1976       py::arg("dilation"),
1977       py::arg("transposed"),
1978       py::arg("output_padding"),
1979       py::arg("groups"),
1980       py::arg("bias_sizes"));
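
// The overload above takes std::optional<std::vector<c10::SymInt>> from
// pybind11 and bridges it by hand to c10::OptionalArrayRef, since ArrayRef
// is a non-owning view that needs a live backing container. The pattern in
// isolation (illustrative sketch):
//
//   std::optional<std::vector<c10::SymInt>> sizes(
//       std::vector<c10::SymInt>{c10::SymInt(3)});
//   c10::OptionalArrayRef<c10::SymInt> ref = std::nullopt;
//   if (sizes) {
//     ref = *sizes;  // view into `sizes`; `sizes` must outlive `ref`
//   }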
1981 
1982   py_module.def(
1983       "_conv_determine_backend_memory_format",
1984       at::native::_determine_backend_memory_format);
1985 
1986   ////////////////////////////////////////////////////////////////////////////////
1987   // Scaled Dot Product Attention utilities
1988   ////////////////////////////////////////////////////////////////////////////////
1989   py::class_<sdp::sdp_params>(py_module, "_SDPAParams")
1990       .def(py::init([](at::Tensor const& query,
1991                        at::Tensor const& key,
1992                        at::Tensor const& value,
1993                        std::optional<at::Tensor> attn_mask,
1994                        double dropout,
1995                        bool is_causal,
1996                        bool enable_gqa) {
1997         return sdp::sdp_params{
1998             query,
1999             key,
2000             value,
2001             std::move(attn_mask),
2002             dropout,
2003             is_causal,
2004             enable_gqa};
2005       }))
2006       .def_readonly("query", &sdp::sdp_params::query)
2007       .def_readonly("key", &sdp::sdp_params::key)
2008       .def_readonly("value", &sdp::sdp_params::value)
2009       .def_readonly("attn_mask", &sdp::sdp_params::attn_mask)
2010       .def_readonly("dropout", &sdp::sdp_params::dropout)
2011       .def_readonly("is_causal", &sdp::sdp_params::is_causal)
2012       .def_readonly("enable_gqa", &sdp::sdp_params::enable_gqa);
2013 
2014   py::enum_<sdp::SDPBackend>(
2015       py_module,
2016       "_SDPBackend",
2017       "An enum-like class that contains the different backends for scaled dot product attention.\n\n... warning:: This class is in beta and subject to change.\n\n"
2018       "This backend class is designed to be used with the sdpa_kernel context manager."
2019       "See :func: torch.nn.attention.sdpa_kernel for more details.")
2020       .value("ERROR", sdp::SDPBackend::error)
2021       .value("MATH", sdp::SDPBackend::math)
2022       .value("FLASH_ATTENTION", sdp::SDPBackend::flash_attention)
2023       .value("EFFICIENT_ATTENTION", sdp::SDPBackend::efficient_attention)
2024       .value("CUDNN_ATTENTION", sdp::SDPBackend::cudnn_attention)
2025       .value("OVERRIDEABLE", sdp::SDPBackend::overrideable);
2026 
2027   py_module.def("_is_flash_attention_available", []() {
2028 #ifdef USE_CUDA
2029     return sdp::is_flash_attention_available();
2030 #else
2031     return false;
2032 #endif
2033   });
2034   py_module.def(
2035       "_can_use_flash_attention",
2036       [](const sdp::sdp_params& params, bool debug) {
2037 #ifdef USE_CUDA
2038         return sdp::can_use_flash_attention(params, debug);
2039 #else
2040         return false;
2041 #endif
2042       });
2043   py_module.def(
2044       "_can_use_mem_efficient_attention",
2045       [](const sdp::sdp_params& params, bool debug) {
2046 #ifdef USE_CUDA
2047         return sdp::can_use_mem_efficient_attention(params, debug);
2048 #else
2049         return false;
2050 #endif
2051       });
2052 
2053   py::enum_<at::LinalgBackend>(py_module, "_LinalgBackend")
2054       .value("Default", at::LinalgBackend::Default)
2055       .value("Cusolver", at::LinalgBackend::Cusolver)
2056       .value("Magma", at::LinalgBackend::Magma);
2057 
2058   py_module.def("_set_linalg_preferred_backend", [](at::LinalgBackend b) {
2059     at::globalContext().setLinalgPreferredBackend(b);
2060   });
2061   py_module.def("_get_linalg_preferred_backend", []() {
2062     return at::globalContext().linalgPreferredBackend();
2063   });
2064 
2065   py::enum_<at::BlasBackend>(py_module, "_BlasBackend")
2066       .value("Cublas", at::BlasBackend::Cublas)
2067       .value("Cublaslt", at::BlasBackend::Cublaslt);
2068 
2069   py_module.def("_set_blas_preferred_backend", [](at::BlasBackend b) {
2070     at::globalContext().setBlasPreferredBackend(b);
2071   });
2072   py_module.def("_get_blas_preferred_backend", []() {
2073     return at::globalContext().blasPreferredBackend();
2074   });
2075 
2076   py_module.def(
2077       "_construct_storage_from_data_pointer",
2078       [](int64_t data_ptr, c10::Device device, size_t size_bytes) {
2079         return c10::Storage(
2080             c10::Storage::use_byte_size_t(),
2081             size_bytes,
2082             // NOLINTNEXTLINE(performance-no-int-to-ptr)
2083             at::DataPtr(reinterpret_cast<void*>(data_ptr), device));
2084       });
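
// The DataPtr above carries no deleter, so the resulting Storage does not
// own the memory: the caller must keep the allocation alive for the
// storage's lifetime. The non-owning construction in isolation (sketch;
// `raw` and `nbytes` are stand-ins):
//
//   void* raw = /* externally owned allocation */ nullptr;
//   size_t nbytes = 0;
//   c10::Storage s(
//       c10::Storage::use_byte_size_t(),
//       nbytes,
//       at::DataPtr(raw, c10::Device(c10::DeviceType::CPU)));
//   // `s` aliases `raw`; freeing `raw` first leaves `s` dangling.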
2085 
2086   py_module.def(
2087       "_stash_obj_in_tls", [](const std::string& key, py::handle arg) {
2088         at::impl::ThreadLocalPythonObjects::get_state().set(
2089             key,
2090             std::make_shared<c10::SafePyObject>(arg.ptr(), getPyInterpreter()));
2091       });
2092 
2093   py_module.def("_get_obj_in_tls", [](const std::string& key) -> py::handle {
2094     auto safe_pyobject =
2095         at::impl::ThreadLocalPythonObjects::get_state().get(key);
2096     auto obj = safe_pyobject->ptr(getPyInterpreter());
2097     return py::handle(obj);
2098   });
2099 
2100   py_module.def("_is_key_in_tls", [](const std::string& key) -> bool {
2101     return at::impl::ThreadLocalPythonObjects::get_state().contains(key);
2102   });
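
// Together, _stash_obj_in_tls / _get_obj_in_tls / _is_key_in_tls expose a
// per-thread key/value stash for Python objects. Round-tripping through the
// same C++ state (sketch; `obj` is an illustrative PyObject*):
//
//   auto& tls = at::impl::ThreadLocalPythonObjects::get_state();
//   tls.set("key",
//           std::make_shared<c10::SafePyObject>(obj, getPyInterpreter()));
//   if (tls.contains("key")) {
//     PyObject* back = tls.get("key")->ptr(getPyInterpreter());
//   }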
2103 
2104   py_module.def("_accelerator_hooks_device_count", []() {
2105     auto device_type = at::getAccelerator();
2106     if (device_type.has_value()) {
2107       return at::globalContext()
2108           .getAcceleratorHooksInterface(device_type.value())
2109           .deviceCount();
2110     }
2111     return c10::DeviceIndex(-1);
2112   });
2113 
2114   py_module.def(
2115       "_accelerator_hooks_set_current_device",
2116       [](c10::DeviceIndex device_index) {
2117         auto device_type = at::getAccelerator();
2118         if (device_type.has_value()) {
2119           at::globalContext()
2120               .getAcceleratorHooksInterface(device_type.value())
2121               .setCurrentDevice(device_index);
2122         }
2123       });
2124 
2125   py_module.def("_accelerator_hooks_get_current_device", []() {
2126     auto device_type = at::getAccelerator();
2127     if (device_type.has_value()) {
2128       return at::globalContext()
2129           .getAcceleratorHooksInterface(device_type.value())
2130           .getCurrentDevice();
2131     }
2132     return c10::DeviceIndex(-1);
2133   });
2134 
2135   py_module.def(
2136       "_accelerator_hooks_exchange_device", [](c10::DeviceIndex device_index) {
2137         auto device_type = at::getAccelerator();
2138         if (device_type.has_value()) {
2139           return at::globalContext()
2140               .getAcceleratorHooksInterface(device_type.value())
2141               .exchangeDevice(device_index);
2142         }
2143         return c10::DeviceIndex(-1);
2144       });
2145 
2146   py_module.def(
2147       "_accelerator_hooks_maybe_exchange_device",
2148       [](c10::DeviceIndex device_index) {
2149         auto device_type = at::getAccelerator();
2150         if (device_type.has_value()) {
2151           return at::globalContext()
2152               .getAcceleratorHooksInterface(device_type.value())
2153               .maybeExchangeDevice(device_index);
2154         }
2155         return c10::DeviceIndex(-1);
2156       });
2157 
2158   py_module.def(
2159       "_get_accelerator",
2160       [](std::optional<bool> check = std::nullopt) {
2161         return c10::Device(
2162             at::getAccelerator(check.value_or(false))
2163                 .value_or(c10::DeviceType::CPU),
2164             -1);
2165       },
2166       py::arg("check") = nullptr);
2167 
2168 #ifdef USE_CUDA
2169   PyObject* has_cuda = Py_True;
2170 #else
2171   PyObject* has_cuda = Py_False;
2172 #endif
2173 
2174 #ifdef USE_MPS
2175   PyObject* has_mps = Py_True;
2176 #else
2177   PyObject* has_mps = Py_False;
2178 #endif
2179 
2180 #ifdef USE_XPU
2181   PyObject* has_xpu = Py_True;
2182 #else
2183   PyObject* has_xpu = Py_False;
2184 #endif
2185 
2186   ASSERT_TRUE(set_module_attr("_has_cuda", has_cuda));
2187   ASSERT_TRUE(
2188       set_module_attr("_has_magma", at::hasMAGMA() ? Py_True : Py_False));
2189   ASSERT_TRUE(set_module_attr("_has_mps", has_mps));
2190   ASSERT_TRUE(set_module_attr("_has_xpu", has_xpu));
2191   ASSERT_TRUE(
2192       set_module_attr("_has_mkldnn", at::hasMKLDNN() ? Py_True : Py_False));
2193 
2194 #ifdef _GLIBCXX_USE_CXX11_ABI
2195   ASSERT_TRUE(set_module_attr(
2196       "_GLIBCXX_USE_CXX11_ABI", _GLIBCXX_USE_CXX11_ABI ? Py_True : Py_False));
2197 #else
2198   ASSERT_TRUE(set_module_attr("_GLIBCXX_USE_CXX11_ABI", Py_False));
2199 #endif
2200 
2201 // See note [Pybind11 ABI constants]
2202 #define SET_STR_DEFINE(name) \
2203   ASSERT_TRUE(set_module_attr("_" #name, THPUtils_packString(name)))
2204 
2205 #ifdef PYBIND11_COMPILER_TYPE
2206   SET_STR_DEFINE(PYBIND11_COMPILER_TYPE);
2207 #else
2208   ASSERT_TRUE(
2209       set_module_attr("_" C10_STRINGIZE(PYBIND11_COMPILER_TYPE), Py_None));
2210 #endif
2211 
2212 #ifdef PYBIND11_STDLIB
2213   SET_STR_DEFINE(PYBIND11_STDLIB);
2214 #else
2215   ASSERT_TRUE(set_module_attr("_" C10_STRINGIZE(PYBIND11_STDLIB), Py_None));
2216 #endif
2217 
2218 #ifdef PYBIND11_BUILD_ABI
2219   SET_STR_DEFINE(PYBIND11_BUILD_ABI);
2220 #else
2221   ASSERT_TRUE(set_module_attr("_" C10_STRINGIZE(PYBIND11_BUILD_ABI), Py_None));
2222 #endif
2223 #undef SET_STR_DEFINE
2224 
2225   py_module.def(
2226       "_set_conj", [](const at::Tensor& x, bool conj) { x._set_conj(conj); });
2227   py_module.def(
2228       "_set_neg", [](const at::Tensor& x, bool neg) { x._set_neg(neg); });
2229   py_module.def("_get_tensor_metadata", &torch::jit::getTensorMetadata);
2230   py_module.def(
2231       "_set_tensor_metadata",
2232       static_cast<void (*)(
2233           const at::Tensor&, std::unordered_map<std::string, bool>)>(
2234           torch::jit::setTensorMetadata));
2235   py_module.def("_dispatch_key_set", [](const at::Tensor& x) {
2236     return toString(x.key_set());
2237   });
2238   py_module.def(
2239       "_has_storage", [](const at::Tensor& x) { return x.has_storage(); });
2240 
2241   py_module.def("_set_meta_in_tls_dispatch_include", [](bool meta_in_tls) {
2242     auto local_keyset = c10::impl::tls_local_dispatch_key_set();
2243     c10::DispatchKeySet key_set({at::DispatchKey::Meta});
2244     if (meta_in_tls) {
2245       local_keyset.included_ = local_keyset.included_ | key_set;
2246     } else {
2247       local_keyset.included_ =
2248           local_keyset.included_.remove_backend(c10::BackendComponent::MetaBit);
2249     }
2250     c10::impl::_force_tls_local_dispatch_key_set(local_keyset);
2251   });
2252 
2253   py_module.def("_meta_in_tls_dispatch_include", []() {
2254     auto local_keyset = c10::impl::tls_local_dispatch_key_set();
2255     return local_keyset.included_.has_backend(c10::BackendComponent::MetaBit);
2256   });
2257 
2258   py_module.def("_dump_local_tls_set", []() {
2259     auto local_keyset = c10::impl::tls_local_dispatch_key_set();
2260     std::cout << "Included: " << toString(local_keyset.included_) << "\n";
2261     std::cout << "Excluded: " << toString(local_keyset.excluded_) << "\n";
2262   });
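
// These debug helpers poke at the thread-local dispatch key set directly.
// Toggling Meta into the include set and restoring the previous state
// (illustrative sketch built from the same c10::impl calls used above):
//
//   auto saved = c10::impl::tls_local_dispatch_key_set();
//   auto modified = saved;
//   modified.included_ =
//       modified.included_ | c10::DispatchKeySet({at::DispatchKey::Meta});
//   c10::impl::_force_tls_local_dispatch_key_set(modified);
//   // ... run code that should see Meta in the TLS include set ...
//   c10::impl::_force_tls_local_dispatch_key_set(saved);  // restore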
2263 
2264   py_module.def(
2265       "_should_allow_numbers_as_tensors", [](const std::string& name) {
2266         return torch::should_allow_numbers_as_tensors(name);
2267       });
2268 
2269   py_module.def(
2270       "_group_tensors_by_device_and_dtype",
2271       [](const std::vector<std::vector<std::optional<at::Tensor>>>&
2272              nested_tensorlist,
2273          const bool with_indices) {
2274         return at::native::_group_tensors_by_first_tensors_device_and_dtype(
2275             nested_tensorlist, with_indices);
2276       });
2277 
2278   py_module.def(
2279       "_storage_address",
2280       [](const at::Tensor& tensor) {
2281         return reinterpret_cast<std::intptr_t>(
2282             tensor.storage().unsafeGetStorageImpl());
2283       },
2284       "Gets the memory address of the Tensor's StorageImpl.");
2285 
2286   py_module.def(
2287       "_data_address",
2288       [](const at::Tensor& tensor) {
2289         return reinterpret_cast<std::intptr_t>(tensor.storage().data());
2290       },
2291       "Gets the memory address of the Tensor's data pointer.");
2292 
2293   py_module.def(
2294       "_is_cow_tensor",
2295       [](const at::Tensor& tensor) {
2296         return c10::impl::cow::is_cow_data_ptr(tensor.storage().data_ptr());
2297       },
2298       "Checks if a tensor's data pointer is COW");
2299 
2300   py_module.def(
2301       "_get_cudnn_batch_norm_reserve_space_size",
2302       [](const at::Tensor& input, bool training) {
2303 #ifdef USE_CUDA
2304         return at::native::_get_cudnn_batch_norm_reserve_space_size(
2305             input, training);
2306 #else
2307         TORCH_CHECK(false, "PyTorch was not built with cuda");
2308 #endif
2309       },
2310       py::arg("input"),
2311       py::arg("training"));
2312 
2313   py::enum_<at::native::BatchNormBackend>(py_module, "_BatchNormBackend")
2314       .value("Native", at::native::BatchNormBackend::Native)
2315       .value("Cudnn", at::native::BatchNormBackend::Cudnn)
2316       .value("Miopen", at::native::BatchNormBackend::Miopen);
2317 
2318   py_module.def(
2319       "_select_batch_norm_backend",
2320       [](const at::Tensor& input,
2321          const at::Tensor& weight,
2322          const at::Tensor& bias,
2323          const at::Tensor& running_mean,
2324          const at::Tensor& running_var,
2325          bool training,
2326          double eps) {
2327         return at::native::_select_batch_norm_backend(
2328             input, weight, bias, running_mean, running_var, training, eps);
2329       },
2330       py::arg("input"),
2331       py::arg("weight"),
2332       py::arg("bias"),
2333       py::arg("running_mean"),
2334       py::arg("running_var"),
2335       py::arg("training"),
2336       py::arg("eps"));
2337 
2338   const auto& defaultGenerator = at::detail::getDefaultCPUGenerator();
2339   THPDefaultCPUGenerator =
2340       (THPGenerator*)THPGenerator_initDefaultGenerator(defaultGenerator);
2341   // This reference is meant to be given away, so no need to incref here.
2342   ASSERT_TRUE(set_module_attr(
2343       "default_generator",
2344       (PyObject*)THPDefaultCPUGenerator,
2345       /* incref= */ false));
2346   ASSERT_TRUE(set_module_attr(
2347       "DisableTorchFunctionSubclass",
2348       (PyObject*)THPModule_DisableTorchFunctionSubclassType(),
2349       /* incref= */ false));
2350   ASSERT_TRUE(set_module_attr(
2351       "DisableTorchFunction",
2352       (PyObject*)THPModule_DisableTorchFunctionType(),
2353       /* incref= */ false));
2354   torch::set_disabled_torch_function_impl(
2355       PyObject_GetAttrString(module, "_disabled_torch_function_impl"));
2356   ASSERT_TRUE(torch::disabled_torch_function_impl() != nullptr);
2357   torch::set_disabled_torch_dispatch_impl(
2358       PyObject_GetAttrString(module, "_disabled_torch_dispatch_impl"));
2359   ASSERT_TRUE(torch::disabled_torch_dispatch_impl() != nullptr);
2360   return module;
2361   END_HANDLE_TH_ERRORS
2362 }
2363 
2364 // Checks that the _C shared library isn't initialized multiple times. This
2365 // can happen if the same csrc files are compiled into multiple shared
2366 // libraries.
2367 inline void pytorch_duplicate_guard() {
2368   static int initialized = 0;
2369   if (initialized) {
2370     fmt::print(stderr, "pytorch: _C shared library re-initialized\n");
2371     abort();
2372   }
2373   initialized = 1;
2375 }
2376 
2377 struct call_duplicate_guard {
2378   call_duplicate_guard() {
2379     pytorch_duplicate_guard();
2380   }
2381 };
2382 
2383 static call_duplicate_guard _call_duplicate_guard;
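
// The guard relies on static initialization: constructing the file-scope
// `_call_duplicate_guard` runs once per loaded copy of this translation
// unit, so loading a second copy of _C into the process aborts. The idiom in
// isolation (illustrative sketch):
//
//   struct RunAtLoad {
//     RunAtLoad() { /* executes when the shared library is loaded */ }
//   };
//   static RunAtLoad run_at_load_instance;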
2384