xref: /aosp_15_r20/external/pytorch/torch/csrc/autograd/init.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/csrc/python_headers.h>

#include <ATen/PythonTorchFunctionTLS.h>
#include <ATen/SavedTensorHooks.h>
#include <ATen/SequenceNumber.h>
#include <ATen/autocast_mode.h>
#include <ATen/core/PythonFallbackKernel.h>
#include <ATen/record_function.h>
#include <c10/core/DeviceType.h>
#include <c10/core/InferenceMode.h>
#include <c10/core/ScalarType.h>
#include <c10/core/impl/PythonDispatcherTLS.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/VariableTypeUtils.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/autograd_not_implemented_fallback.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/input_metadata.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/profiler_python.h>
#include <torch/csrc/autograd/python_function.h>
#include <torch/csrc/autograd/python_saved_variable_hooks.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/autograd/record_function_ops.h>
#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/autograd/utils/python_arg_parsing.h>
#include <torch/csrc/autograd/utils/wrap_outputs.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/profiler/collection.h>
#include <torch/csrc/profiler/kineto_shim.h>
#include <torch/csrc/utils.h>
#include <torch/csrc/utils/disable_torch_function.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/pycfunction_helpers.h>
#include <torch/csrc/utils/python_raii.h>
#include <torch/csrc/utils/python_torch_function_mode.h>

#include <set>
#include <unordered_set>
#include <utility>

using torch::impl::py_context_manager;
using torch::impl::py_context_manager_DEPRECATED;
45 
46 namespace {
47 
48 struct DisableFuncTorch {
DisableFuncTorch__anon67aec8d30111::DisableFuncTorch49   DisableFuncTorch()
50       : front_guard_(c10::DispatchKey::FuncTorchDynamicLayerFrontMode),
51         back_guard_(c10::DispatchKey::FuncTorchDynamicLayerBackMode) {}
52   c10::impl::ExcludeDispatchKeyGuard front_guard_;
53   c10::impl::ExcludeDispatchKeyGuard back_guard_;
54 };
55 
56 struct DisableAutocast {
57   c10::impl::ExcludeDispatchKeyGuard guard_{c10::autocast_dispatch_keyset};
58 };
59 
60 struct EnableTorchFunction {
EnableTorchFunction__anon67aec8d30111::EnableTorchFunction61   EnableTorchFunction()
62       : old_(at::impl::PythonTorchFunctionTLS::get_disabled_state()) {
63     at::impl::PythonTorchFunctionTLS::set_disabled_state(
64         at::impl::TorchFunctionDisabledState::ENABLED);
65   }
~EnableTorchFunction__anon67aec8d30111::EnableTorchFunction66   ~EnableTorchFunction() {
67     at::impl::PythonTorchFunctionTLS::set_disabled_state(old_);
68   }
69   at::impl::TorchFunctionDisabledState old_;
70 };
71 
72 struct EnablePythonDispatcher {
EnablePythonDispatcher__anon67aec8d30111::EnablePythonDispatcher73   EnablePythonDispatcher() : old_(c10::impl::PythonDispatcherTLS::get_state()) {
74     c10::impl::PythonDispatcherTLS::set_state(getPyInterpreter());
75   }
~EnablePythonDispatcher__anon67aec8d30111::EnablePythonDispatcher76   ~EnablePythonDispatcher() {
77     c10::impl::PythonDispatcherTLS::set_state(old_);
78   }
79   c10::impl::PyInterpreter* old_;
80 };
81 
82 struct EnablePreDispatch {
EnablePreDispatch__anon67aec8d30111::EnablePreDispatch83   EnablePreDispatch() : guard_(c10::DispatchKey::PreDispatch) {}
84   c10::impl::IncludeDispatchKeyGuard guard_;
85 };
86 
87 } // namespace
88 
THPAutograd_initExtension(PyObject * _unused,PyObject * unused)89 PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
90   using namespace torch::autograd::profiler;
91   using namespace torch::profiler::impl;
92   auto tensor_module = THPObjectPtr(PyImport_ImportModule("torch._tensor"));
93   if (!tensor_module)
94     return nullptr;
95 
96   // NOTE: "leaks" THPVariableClass
97   THPVariableClass = PyObject_GetAttrString(tensor_module, "Tensor");
98   if (!THPVariableClass)
99     return nullptr;
100 
101   auto autograd_module = THPObjectPtr(PyImport_ImportModule("torch.autograd"));
102   if (!autograd_module)
103     return nullptr;
104 
105   // NOTE: "leaks" Function
106   THPFunctionClass = PyObject_GetAttrString(autograd_module, "Function");
107   if (!THPFunctionClass)
108     return nullptr;
109 
110   // NOTE: "leaks" GradientEdge
111   auto autograd_graph_mod =
112       THPObjectPtr(PyImport_ImportModule("torch.autograd.graph"));
113   THPGradientEdgeClass =
114       PyObject_GetAttrString(autograd_graph_mod, "GradientEdge");
115   if (!THPGradientEdgeClass)
116     return nullptr;
117 
118   auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
119   if (!torch_C_module)
120     return nullptr;
121   auto _C_m = py::handle(torch_C_module).cast<py::module>();
122   auto m = _C_m.def_submodule("_autograd", "autograd bindings");
123 
124   auto parameter_module =
125       THPObjectPtr(PyImport_ImportModule("torch.nn.parameter"));
126   if (!parameter_module)
127     return nullptr;
128 
129   // NOTE: "leaks" ParameterClass
130   ParameterClass = PyObject_GetAttrString(parameter_module, "Parameter");
131   if (!ParameterClass)
132     return nullptr;
133 
134   py::class_<LegacyEvent>(m, "ProfilerEvent")
135       .def("kind", &LegacyEvent::kindStr)
136       .def("name", [](const LegacyEvent& e) { return e.name(); })
137       .def("thread_id", &LegacyEvent::threadId)
138       .def("fwd_thread_id", &LegacyEvent::fwdThreadId)
139       .def("device", &LegacyEvent::device)
140       .def("cpu_elapsed_us", &LegacyEvent::cpuElapsedUs)
141       .def("cuda_elapsed_us", &LegacyEvent::cudaElapsedUs)
142       .def("has_cuda", &LegacyEvent::hasCuda)
143       .def("shapes", &LegacyEvent::shapes)
144       .def("cpu_memory_usage", &LegacyEvent::cpuMemoryUsage)
145       .def("cuda_memory_usage", &LegacyEvent::cudaMemoryUsage)
146       .def("handle", &LegacyEvent::handle)
147       .def("node_id", &LegacyEvent::nodeId)
148       .def("is_remote", &LegacyEvent::isRemote)
149       .def("sequence_nr", &LegacyEvent::sequenceNr)
150       .def("stack", &LegacyEvent::stack)
151       .def("scope", &LegacyEvent::scope)
152       .def("correlation_id", &LegacyEvent::correlationId)
153       .def("start_us", &LegacyEvent::cpuUs)
154       .def("flops", &LegacyEvent::flops)
155       .def("is_async", &LegacyEvent::isAsync);
156 
157   py::enum_<c10::DeviceType>(m, "DeviceType")
158       .value("CPU", c10::DeviceType::CPU)
159       .value("CUDA", c10::DeviceType::CUDA)
160       .value("MKLDNN", c10::DeviceType::MKLDNN)
161       .value("OPENGL", c10::DeviceType::OPENGL)
162       .value("OPENCL", c10::DeviceType::OPENCL)
163       .value("IDEEP", c10::DeviceType::IDEEP)
164       .value("HIP", c10::DeviceType::HIP)
165       .value("FPGA", c10::DeviceType::FPGA)
166       .value("MAIA", c10::DeviceType::MAIA)
167       .value("XLA", c10::DeviceType::XLA)
168       .value("Vulkan", c10::DeviceType::Vulkan)
169       .value("Metal", c10::DeviceType::Metal)
170       .value("XPU", c10::DeviceType::XPU)
171       .value("MPS", c10::DeviceType::MPS)
172       .value("MTIA", c10::DeviceType::MTIA)
173       .value("Meta", c10::DeviceType::Meta)
174       .value("HPU", c10::DeviceType::HPU)
175       .value("VE", c10::DeviceType::VE)
176       .value("Lazy", c10::DeviceType::Lazy)
177       .value("IPU", c10::DeviceType::IPU)
178       .value("PrivateUse1", c10::DeviceType::PrivateUse1);
179 
180   using torch::autograd::CreationMeta;
181   py::enum_<CreationMeta>(m, "CreationMeta")
182       .value("DEFAULT", CreationMeta::DEFAULT)
183       .value("IN_CUSTOM_FUNCTION", CreationMeta::IN_CUSTOM_FUNCTION)
184       .value("MULTI_OUTPUT_NODE", CreationMeta::MULTI_OUTPUT_NODE)
185       .value("NO_GRAD_MODE", CreationMeta::NO_GRAD_MODE)
186       .value("INFERENCE_MODE", CreationMeta::INFERENCE_MODE);
187 
188   py::class_<torch::autograd::InputMetadata>(m, "_InputMetadata")
189       .def_property_readonly(
190           "dtype",
191           [](const torch::autograd::InputMetadata& m) {
192             PyObject* raw_obj =
193                 (PyObject*)torch::getTHPDtype(m.dtype().toScalarType());
194             return py::reinterpret_borrow<py::object>(raw_obj);
195           })
196       .def_property_readonly("device", &torch::autograd::InputMetadata::device)
197       .def_property_readonly(
198           "shape", &torch::autograd::InputMetadata::shape_as_dim_vector)
199       .def_property_readonly(
200           "is_nested_tensor", &torch::autograd::InputMetadata::is_nested_tensor)
201       .def_property_readonly(
202           "is_cpp_nested_tensor",
203           &torch::autograd::InputMetadata::is_cpp_nested_tensor);
204 
205   py::class_<KinetoEvent>(m, "_KinetoEvent")
206       // name of the event
207       .def("name", [](const KinetoEvent& e) { return e.name(); })
208       // PyTorch thread id of the start callback
209       .def(
210           "start_thread_id",
211           [](const KinetoEvent& e) { return e.startThreadId(); })
212       // PyTorch thread id of the end callback
213       .def(
214           "end_thread_id", [](const KinetoEvent& e) { return e.endThreadId(); })
215       // for events of scope BACKWARD_FUNCTION - PyTorch thread id
216       // of the corresponding forward op
217       .def(
218           "fwd_thread_id", [](const KinetoEvent& e) { return e.fwdThreadId(); })
219       // together with fwd_thread_id, used to uniquely identify
220       // the forward op
221       .def("sequence_nr", [](const KinetoEvent& e) { return e.sequenceNr(); })
222       // absolute start time (since unix epoch) in ns
223       .def("start_ns", [](const KinetoEvent& e) { return e.startNs(); })
224       // absolute end time (since unix epoch) in ns
225       .def("end_ns", [](const KinetoEvent& e) { return e.endNs(); })
226       // duration in ns
227       .def("duration_ns", [](const KinetoEvent& e) { return e.durationNs(); })
228       // used for correlation between high-level PyTorch events
229       // and low-level device events
230       .def(
231           "correlation_id",
232           [](const KinetoEvent& e) { return e.correlationId(); })
233       // shapes of input tensors
234       .def("shapes", [](const KinetoEvent& e) { return e.shapes().vec(); })
235       .def("dtypes", [](const KinetoEvent& e) { return e.dtypes().vec(); })
236       .def(
237           "concrete_inputs",
238           [](const KinetoEvent& e) {
239             std::vector<py::object> as_pyobj;
240             std::transform(
241                 e.concreteInputs().begin(),
242                 e.concreteInputs().end(),
243                 std::back_inserter(as_pyobj),
244                 [](const c10::IValue& val) {
245                   return torch::jit::toPyObject(val);
246                 });
247             return as_pyobj;
248           })
249       .def(
250           "kwinputs",
251           [](const KinetoEvent& e) {
252             std::unordered_map<std::string, py::object> inputs;
253             for (const auto& [key, value] : e.kwinputs()) {
254               inputs[key] = torch::jit::toPyObject(value);
255             }
256             return inputs;
257           })
258       // stack traces of the PyTorch CPU events
259       .def("stack", [](const KinetoEvent& e) { return e.stack().vec(); })
260       // type of the RecordFunction that generated a PyTorch CPU event
261       // (op, torchscript function, user label, etc)
262       .def("scope", [](const KinetoEvent& e) { return e.scope(); })
263       // device number, for CPU - process id
264       .def("device_index", [](const KinetoEvent& e) { return e.deviceIndex(); })
265       // for CUDA - stream id, for CPU - start thread id
266       .def(
267           "device_resource_id",
268           [](const KinetoEvent& e) { return e.deviceResourceId(); })
269       // device type
270       .def("device_type", [](const KinetoEvent& e) { return e.deviceType(); })
271       // correlation id of a linked event
272       .def(
273           "linked_correlation_id",
274           [](const KinetoEvent& e) { return e.linkedCorrelationId(); })
275       // compute flops
276       .def("flops", [](const KinetoEvent& e) { return e.flops(); })
277       // Whether this is async event or not
278       .def("is_async", [](const KinetoEvent& e) { return e.isAsync(); })
279       .def("cuda_elapsed_us", &KinetoEvent::cudaElapsedUs)
280       .def("privateuse1_elapsed_us", &KinetoEvent::privateuse1ElapsedUs)
281       .def(
282           "is_user_annotation",
283           [](const KinetoEvent& e) {
284             return e.activityType() ==
285                 (uint8_t)libkineto::ActivityType::USER_ANNOTATION ||
286                 e.activityType() ==
287                 (uint8_t)libkineto::ActivityType::GPU_USER_ANNOTATION;
288           })
289       .def("nbytes", [](const KinetoEvent& e) { return e.nBytes(); });
290 
291   m.def("_soft_assert_raises", &setSoftAssertRaises);
292   m.def("_get_sequence_nr", &at::sequence_number::peek);
293 
294   py::class_<ProfilerResult>(m, "_ProfilerResult")
295       .def("trace_start_ns", &ProfilerResult::trace_start_ns)
296       .def("events", &ProfilerResult::events)
297       .def("experimental_event_tree", &ProfilerResult::event_tree)
298 #ifdef USE_KINETO
299       .def("save", &ProfilerResult::save)
300 #endif // USE_KINETO
301       ;
302 
303   m.def(
304       "_enable_profiler",
305       &enableProfiler,
306       py::arg("config"),
307       py::arg("activities"),
308       py::arg("scopes") = std::unordered_set<at::RecordScope>());
309   m.def("_disable_profiler", disableProfiler);
310   m.def(
311       "_prepare_profiler",
312       prepareProfiler,
313       py::call_guard<py::gil_scoped_release>());
314   m.def(
315       "_toggle_collection_dynamic",
316       toggleCollectionDynamic,
317       py::call_guard<py::gil_scoped_release>());
318   m.def("_add_metadata_json", addMetadataJson); // Only if `USE_KINETO` is set
319   m.def("_kineto_step", profilerStep); // Only if `USE_KINETO` is set
320   m.def("kineto_available", []() { return torch::profiler::kKinetoAvailable; });
321 
322   // NOTICE: These record functions are not torch operators and may not show up
323   // in TorchScript tracing, FX transforms, or operator serialization. For these
324   // use cases, please use `torch.profiler.record_function`.
325   // Creates a new profiling scope using RecordFunction and invokes its starting
326   // callbacks.
327   m.def(
328       "_record_function_with_args_enter",
329       [](const std::string& name, const py::args& args) {
330         using torch::autograd::profiler::PythonRecordFunction;
331         auto python_rec = c10::make_intrusive<PythonRecordFunction>(
332             at::RecordScope::USER_SCOPE);
333         auto* rec = &python_rec->record;
334         if (rec->isActive()) {
335           if (rec->needsInputs()) {
336             auto iv_inputs = std::vector<c10::IValue>();
337             for (const auto& arg : args) {
338               iv_inputs.push_back(torch::jit::toTypeInferredIValue(arg));
339             }
340             rec->before(
341                 name,
342                 c10::ArrayRef<const c10::IValue>(
343                     iv_inputs.data(), iv_inputs.size()));
344           } else {
345             rec->before(name);
346           }
347         }
348         return torch::jit::toPyObject(std::move(python_rec));
349       });
350 
351   // Ends the profiling scope created with record_function_with_param_enter.
352   m.def("_record_function_with_args_exit", [](const py::object& obj) {
353     using torch::autograd::profiler::PythonRecordFunction;
354     auto python_record = torch::jit::toCustomClass<PythonRecordFunction>(obj);
355 
356     // We don't actually need to do anything with handle just need to persist
357     // the lifetime until now.
358     python_record->record.end();
359   });
360 
361   m.def("_supported_activities", []() {
362     std::set<torch::profiler::impl::ActivityType> activities{
363         torch::profiler::impl::ActivityType::CPU};
364 #if defined(USE_KINETO) && \
365     (!defined(LIBKINETO_NOCUPTI) || !defined(LIBKINETO_NOROCTRACER))
366     if (at::hasMTIA()) {
367       activities.insert(torch::profiler::impl::ActivityType::MTIA);
368     }
369     if (at::getNumGPUs() > 0) {
370       activities.insert(torch::profiler::impl::ActivityType::CUDA);
371     }
372 #elif defined(USE_KINETO)
373     if (at::hasXPU()) {
374       activities.insert(torch::profiler::impl::ActivityType::XPU);
375     }
376     if (at::hasMTIA()) {
377       activities.insert(torch::profiler::impl::ActivityType::MTIA);
378     }
379     if (c10::get_privateuse1_backend() != "privateuseone") {
380       activities.insert(torch::profiler::impl::ActivityType::PrivateUse1);
381     }
382 #endif
383     return activities;
384   });
385 
386   m.def("_unsafe_set_version_counter", [](const at::Tensor& t, int64_t i) {
387     auto vc = torch::autograd::impl::version_counter(t);
388     vc.set_version(i);
389   });
390 
391   m.def("_enable_profiler_legacy", enableProfilerLegacy);
392   py::class_<ProfilerDisableOptions>(m, "_ProfilerDisableOptions")
393       .def(py::init<bool, bool>());
394   m.def(
395       "_disable_profiler_legacy",
396       disableProfilerLegacy,
397       py::arg("profiler_disable_options") = ProfilerDisableOptions());
398   m.def("_profiler_enabled", profilerEnabled);
399   m.def("_profiler_type", torch::profiler::impl::profilerType);
400   m.def("_enable_record_function", [](bool enable) {
401     at::enableRecordFunction(enable);
402   });
403   m.def("_set_empty_test_observer", [](bool is_global, double sampling_prob) {
404     auto cb =
405         at::RecordFunctionCallback(nullptr).needsInputs(true).samplingProb(
406             sampling_prob);
407     if (is_global) {
408       at::addGlobalCallback(cb);
409     } else {
410       at::addThreadLocalCallback(cb);
411     }
412   });
413   m.def("_clear_callbacks", []() { at::clearCallbacks(); });
414   m.def(
415       "_saved_tensors_hooks_is_enabled",
416       at::SavedTensorDefaultHooks::is_enabled);
417   m.def("_saved_tensors_hooks_enable", at::SavedTensorDefaultHooks::enable);
418   m.def("_saved_tensors_hooks_disable", at::SavedTensorDefaultHooks::disable);
419   m.def(
420       "_saved_tensors_hooks_set_tracing",
421       at::SavedTensorDefaultHooks::set_tracing);
422   m.def(
423       "_saved_tensors_hooks_get_disabled_error_message",
424       at::SavedTensorDefaultHooks::get_disabled_error_message);
425   m.def(
426       "_push_saved_tensors_default_hooks",
427       [](py::function& pack_hook, py::function& unpack_hook) {
428         torch::autograd::PyDefaultSavedVariableHooks::push_hooks(
429             pack_hook, unpack_hook);
430       });
431   m.def("_pop_saved_tensors_default_hooks", []() {
432     torch::autograd::PyDefaultSavedVariableHooks::pop_hooks();
433   });
434 
435   m.def("_get_creation_meta", [](const at::Tensor& t) {
436     auto* meta = torch::autograd::impl::get_view_autograd_meta(t);
437     TORCH_CHECK(meta != nullptr);
438     return meta->get_creation_meta();
439   });
440 
441   m.def(
442       "_set_creation_meta",
443       [](const at::Tensor& t, CreationMeta new_creation_meta) {
444         auto* meta = torch::autograd::impl::get_view_autograd_meta(t);
445         TORCH_CHECK(meta != nullptr);
446         meta->set_creation_meta(new_creation_meta);
447       });
448 
449   m.def("_get_current_graph_task_keep_graph", []() {
450     return torch::autograd::get_current_graph_task_keep_graph();
451   });
452 
453   m.def(
454       "_get_data_attr", [](const at::Tensor& t) { return t.variable_data(); });
455 
456   _C_m.def(
457       "_register_py_class_for_device",
458       [](const std::string& device, py::object python_type_class) {
459         auto cls = python_type_class.ptr();
460         registerPythonTensorClass(device, cls);
461       });
462   _C_m.def("_set_autograd_fallback_mode", [](const std::string& mode) {
463     if (mode == "nothing") {
464       torch::autograd::setAutogradFallbackMode(
465           torch::autograd::AutogradFallbackMode::Nothing);
466       return;
467     }
468     if (mode == "warn") {
469       torch::autograd::setAutogradFallbackMode(
470           torch::autograd::AutogradFallbackMode::Warn);
471       return;
472     }
473     if (mode == "error") {
474       torch::autograd::setAutogradFallbackMode(
475           torch::autograd::AutogradFallbackMode::Error);
476       return;
477     }
478     TORCH_INTERNAL_ASSERT(false, "Unsupported AutogradFallbackMode: ", mode);
479   });
480   _C_m.def("_get_autograd_fallback_mode", []() {
481     auto mode = torch::autograd::getAutogradFallbackMode();
482     switch (mode) {
483       case torch::autograd::AutogradFallbackMode::Nothing:
484         return "nothing";
485       case torch::autograd::AutogradFallbackMode::Warn:
486         return "warn";
487       case torch::autograd::AutogradFallbackMode::Error:
488         return "error";
489       default:
490         TORCH_INTERNAL_ASSERT(false, "Unsupported AutogradFallbackMode");
491     }
492   });
493 
494   _C_m.def("_activate_gpu_trace", []() { activateGPUTrace(); });
495 
496   py_context_manager_DEPRECATED<c10::InferenceMode, bool>(
497       _C_m, "_InferenceMode");
498   py_context_manager<at::impl::RestorePythonTLSSnapshot>(
499       _C_m, "_RestorePythonTLSSnapshot");
500 
501   py_context_manager_DEPRECATED<torch::DisableTorchDispatch>(
502       _C_m, "_DisableTorchDispatch");
503   py_context_manager_DEPRECATED<EnableTorchFunction>(
504       _C_m, "_EnableTorchFunction");
505   py_context_manager_DEPRECATED<EnablePythonDispatcher>(
506       _C_m, "_EnablePythonDispatcher");
507   py_context_manager<c10::impl::DisablePythonDispatcher>(
508       _C_m, "_DisablePythonDispatcher");
509   py_context_manager<EnablePreDispatch>(_C_m, "_EnablePreDispatch");
510   py_context_manager_DEPRECATED<DisableFuncTorch>(_C_m, "_DisableFuncTorch");
511   py_context_manager<DisableAutocast>(_C_m, "_DisableAutocast");
512   py::class_<torch::autograd::SavedVariable>(std::move(m), "SavedTensor")
513       .def(py::init([]() -> torch::autograd::SavedVariable {
514         TORCH_CHECK(
515             false,
516             "Trying to create a SavedTensor object from Python is forbidden.");
517       }))
518       .def(
519           "register_hooks",
520           [](torch::autograd::SavedVariable& s,
521              py::function& pack_hook,
522              py::function& unpack_hook) {
523             // Because we use a py::object, pybind will increment the refcount
524             // of the hook functions for us
525             s.register_hooks(
526                 std::make_unique<torch::autograd::PySavedVariableHooks>(
527                     pack_hook, unpack_hook));
528           });
529 
530   torch::autograd::profiler::python_tracer::init();
531   Py_RETURN_TRUE;
532 }
533 
534 namespace torch::autograd {
535 
set_autocast_enabled(PyObject * _unused,PyObject * args,PyObject * kwargs)536 static PyObject* set_autocast_enabled(
537     PyObject* _unused,
538     PyObject* args,
539     PyObject* kwargs) {
540   HANDLE_TH_ERRORS
541   static PythonArgParser parser(
542       {"set_autocast_enabled(c10::string_view device_type, bool enabled)",
543        "set_autocast_enabled(bool enabled)"}); // this signature is depracated.
544   ParsedArgs<2> parsed_args;
545   auto r = parser.parse(args, kwargs, parsed_args);
546   // Set at::kCUDA as default value to prevent BC-breaking changes.
547   at::DeviceType device_type = at::kCUDA;
548   int enabled_id = 0;
549   if (r.idx == 0) {
550     device_type = at::Device(r.string(0)).type();
551     enabled_id = 1;
552   }
553   auto enabled = r.toBool(enabled_id);
554   at::autocast::set_autocast_enabled(device_type, enabled);
555   Py_RETURN_NONE;
556   END_HANDLE_TH_ERRORS
557 }
558 
is_autocast_enabled(PyObject * _unused,PyObject * args,PyObject * kwargs)559 static PyObject* is_autocast_enabled(
560     PyObject* _unused,
561     PyObject* args,
562     PyObject* kwargs) {
563   HANDLE_TH_ERRORS
564   static PythonArgParser parser(
565       {"is_autocast_enabled(c10::string_view device_type)",
566        "is_autocast_enabled()"}); // this signature is depracated.
567   ParsedArgs<1> parsed_args;
568   auto r = parser.parse(args, kwargs, parsed_args);
569   // Set at::kCUDA as default value to prevent BC-breaking changes.
570   at::DeviceType device_type = at::kCUDA;
571   if (r.idx == 0) {
572     device_type = at::Device(r.string(0)).type();
573   }
574   if (at::autocast::is_autocast_enabled(device_type)) {
575     Py_RETURN_TRUE;
576   } else {
577     Py_RETURN_FALSE;
578   }
579   END_HANDLE_TH_ERRORS
580 }
581 
get_autocast_dtype(PyObject * _unused,PyObject * args,PyObject * kwargs)582 static PyObject* get_autocast_dtype(
583     PyObject* _unused,
584     PyObject* args,
585     PyObject* kwargs) {
586   HANDLE_TH_ERRORS
587   static PythonArgParser parser(
588       {"get_autocast_dtype(c10::string_view device_type)"});
589   ParsedArgs<1> parsed_args;
590   auto r = parser.parse(args, kwargs, parsed_args);
591   auto device_type = at::Device(r.string(0)).type();
592   at::ScalarType current_dtype = at::autocast::get_autocast_dtype(device_type);
593   return utils::wrap(current_dtype);
594   END_HANDLE_TH_ERRORS
595 }
596 
set_autocast_dtype(PyObject * _unused,PyObject * args,PyObject * kwargs)597 static PyObject* set_autocast_dtype(
598     PyObject* _unused,
599     PyObject* args,
600     PyObject* kwargs) {
601   HANDLE_TH_ERRORS
602   static PythonArgParser parser(
603       {"set_autocast_dtype(c10::string_view device_type, ScalarType dtype)"});
604   ParsedArgs<2> parsed_args;
605   auto r = parser.parse(args, kwargs, parsed_args);
606   auto device_type = at::Device(r.string(0)).type();
607   auto dtype = r.scalartype(1);
608   at::autocast::set_autocast_dtype(device_type, dtype);
609   Py_RETURN_NONE;
610   END_HANDLE_TH_ERRORS
611 }
612 
is_any_autocast_enabled(PyObject * _unused,PyObject * arg)613 static PyObject* is_any_autocast_enabled(PyObject* _unused, PyObject* arg) {
614   HANDLE_TH_ERRORS
615   if (at::autocast::is_autocast_enabled(at::kCPU) ||
616       at::autocast::is_autocast_enabled(at::kCUDA) ||
617       at::autocast::is_autocast_enabled(at::kXPU) ||
618       at::autocast::is_autocast_enabled(at::kIPU) ||
619       at::autocast::is_autocast_enabled(at::kXLA) ||
620       at::autocast::is_autocast_enabled(at::kHPU) ||
621       at::autocast::is_autocast_enabled(at::kPrivateUse1)) {
622     Py_RETURN_TRUE;
623   } else {
624     Py_RETURN_FALSE;
625   }
626   END_HANDLE_TH_ERRORS
627 }
628 
is_autocast_available(PyObject * _unused,PyObject * args,PyObject * kwargs)629 static PyObject* is_autocast_available(
630     PyObject* _unused,
631     PyObject* args,
632     PyObject* kwargs) {
633   HANDLE_TH_ERRORS
634   static PythonArgParser parser(
635       {"_is_autocast_available(c10::string_view device_type)"});
636   ParsedArgs<1> parsed_args;
637   auto r = parser.parse(args, kwargs, parsed_args);
638   auto device_type = at::Device(r.string(0)).type();
639   if (at::autocast::is_autocast_available(device_type)) {
640     Py_RETURN_TRUE;
641   } else {
642     Py_RETURN_FALSE;
643   }
644   END_HANDLE_TH_ERRORS
645 }
646 
set_autocast_cpu_enabled(PyObject * _unused,PyObject * arg)647 static PyObject* set_autocast_cpu_enabled(PyObject* _unused, PyObject* arg) {
648   HANDLE_TH_ERRORS
649   TORCH_CHECK_TYPE(
650       PyBool_Check(arg),
651       "enabled must be a bool (got ",
652       Py_TYPE(arg)->tp_name,
653       ")");
654   TORCH_WARN_DEPRECATION(
655       "torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead.")
656   at::autocast::set_autocast_enabled(at::kCPU, arg == Py_True);
657   Py_RETURN_NONE;
658   END_HANDLE_TH_ERRORS
659 }
660 
is_autocast_cpu_enabled(PyObject * _unused,PyObject * arg)661 static PyObject* is_autocast_cpu_enabled(PyObject* _unused, PyObject* arg) {
662   HANDLE_TH_ERRORS
663   TORCH_WARN_DEPRECATION(
664       "torch.is_autocast_cpu_enabled() is deprecated. Please use torch.is_autocast_enabled('cpu') instead.")
665   if (at::autocast::is_autocast_enabled(at::kCPU)) {
666     Py_RETURN_TRUE;
667   } else {
668     Py_RETURN_FALSE;
669   }
670   END_HANDLE_TH_ERRORS
671 }
672 
set_autocast_ipu_enabled(PyObject * _unused,PyObject * arg)673 static PyObject* set_autocast_ipu_enabled(PyObject* _unused, PyObject* arg) {
674   HANDLE_TH_ERRORS
675   TORCH_CHECK_TYPE(
676       PyBool_Check(arg),
677       "enabled must be a bool (got ",
678       Py_TYPE(arg)->tp_name,
679       ")");
680   TORCH_WARN_DEPRECATION(
681       "torch.set_autocast_ipu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('ipu', enabled) instead.")
682   at::autocast::set_autocast_enabled(at::kIPU, arg == Py_True);
683   Py_RETURN_NONE;
684   END_HANDLE_TH_ERRORS
685 }
686 
is_autocast_ipu_enabled(PyObject * _unused,PyObject * arg)687 static PyObject* is_autocast_ipu_enabled(PyObject* _unused, PyObject* arg) {
688   HANDLE_TH_ERRORS
689   TORCH_WARN_DEPRECATION(
690       "torch.is_autocast_ipu_enabled() is deprecated. Please use torch.is_autocast_enabled('ipu') instead.")
691   if (at::autocast::is_autocast_enabled(at::kIPU)) {
692     Py_RETURN_TRUE;
693   } else {
694     Py_RETURN_FALSE;
695   }
696   END_HANDLE_TH_ERRORS
697 }
698 
set_autocast_xla_enabled(PyObject * _unused,PyObject * arg)699 static PyObject* set_autocast_xla_enabled(PyObject* _unused, PyObject* arg) {
700   HANDLE_TH_ERRORS
701   TORCH_CHECK_TYPE(
702       PyBool_Check(arg),
703       "enabled must be a bool (got ",
704       Py_TYPE(arg)->tp_name,
705       ")");
706   TORCH_WARN_DEPRECATION(
707       "torch.set_autocast_xla_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('xla', enabled) instead.")
708   at::autocast::set_autocast_enabled(at::kXLA, arg == Py_True);
709   Py_RETURN_NONE;
710   END_HANDLE_TH_ERRORS
711 }
712 
is_autocast_xla_enabled(PyObject * _unused,PyObject * arg)713 static PyObject* is_autocast_xla_enabled(PyObject* _unused, PyObject* arg) {
714   HANDLE_TH_ERRORS
715   TORCH_WARN_DEPRECATION(
716       "torch.is_autocast_xla_enabled() is deprecated. Please use torch.is_autocast_enabled('xla') instead.")
717   if (at::autocast::is_autocast_enabled(at::kXLA)) {
718     Py_RETURN_TRUE;
719   } else {
720     Py_RETURN_FALSE;
721   }
722   END_HANDLE_TH_ERRORS
723 }
724 
set_autocast_gpu_dtype(PyObject * _unused,PyObject * arg)725 static PyObject* set_autocast_gpu_dtype(PyObject* _unused, PyObject* arg) {
726   HANDLE_TH_ERRORS
727   TORCH_CHECK_TYPE(
728       THPDtype_Check(arg),
729       "dtype must be a torch.dtype (got ",
730       Py_TYPE(arg)->tp_name,
731       ")");
732   TORCH_WARN_DEPRECATION(
733       "torch.set_autocast_gpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cuda', dtype) instead.")
734   at::ScalarType targetType = reinterpret_cast<THPDtype*>(arg)->scalar_type;
735   at::autocast::set_autocast_dtype(at::kCUDA, targetType);
736   Py_RETURN_NONE;
737   END_HANDLE_TH_ERRORS
738 }
739 
set_autocast_cpu_dtype(PyObject * _unused,PyObject * arg)740 static PyObject* set_autocast_cpu_dtype(PyObject* _unused, PyObject* arg) {
741   HANDLE_TH_ERRORS
742   TORCH_CHECK_TYPE(
743       THPDtype_Check(arg),
744       "dtype must be a torch.dtype (got ",
745       Py_TYPE(arg)->tp_name,
746       ")");
747   TORCH_WARN_DEPRECATION(
748       "torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead.")
749   at::ScalarType targetType = reinterpret_cast<THPDtype*>(arg)->scalar_type;
750   at::autocast::set_autocast_dtype(at::kCPU, targetType);
751   Py_RETURN_NONE;
752   END_HANDLE_TH_ERRORS
753 }
754 
set_autocast_ipu_dtype(PyObject * _unused,PyObject * arg)755 static PyObject* set_autocast_ipu_dtype(PyObject* _unused, PyObject* arg) {
756   HANDLE_TH_ERRORS
757   TORCH_CHECK_TYPE(
758       THPDtype_Check(arg),
759       "dtype must be a torch.dtype (got ",
760       Py_TYPE(arg)->tp_name,
761       ")");
762   TORCH_WARN_DEPRECATION(
763       "torch.set_autocast_ipu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('ipu', dtype) instead.")
764   at::ScalarType targetType = reinterpret_cast<THPDtype*>(arg)->scalar_type;
765   at::autocast::set_autocast_dtype(at::kIPU, targetType);
766   Py_RETURN_NONE;
767   END_HANDLE_TH_ERRORS
768 }
769 
set_autocast_xla_dtype(PyObject * _unused,PyObject * arg)770 static PyObject* set_autocast_xla_dtype(PyObject* _unused, PyObject* arg) {
771   HANDLE_TH_ERRORS
772   TORCH_CHECK_TYPE(
773       THPDtype_Check(arg),
774       "dtype must be a torch.dtype (got ",
775       Py_TYPE(arg)->tp_name,
776       ")");
777   TORCH_WARN_DEPRECATION(
778       "torch.set_autocast_xla_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('xla', dtype) instead.")
779   at::ScalarType targetType = reinterpret_cast<THPDtype*>(arg)->scalar_type;
780   at::autocast::set_autocast_dtype(at::kXLA, targetType);
781   Py_RETURN_NONE;
782   END_HANDLE_TH_ERRORS
783 }
784 
get_autocast_gpu_dtype(PyObject * _unused,PyObject * arg)785 static PyObject* get_autocast_gpu_dtype(PyObject* _unused, PyObject* arg) {
786   HANDLE_TH_ERRORS
787   TORCH_WARN_DEPRECATION(
788       "torch.get_autocast_gpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cuda') instead.")
789   at::ScalarType current_dtype = at::autocast::get_autocast_dtype(at::kCUDA);
790   return utils::wrap(current_dtype);
791   END_HANDLE_TH_ERRORS
792 }
793 
get_autocast_cpu_dtype(PyObject * _unused,PyObject * arg)794 static PyObject* get_autocast_cpu_dtype(PyObject* _unused, PyObject* arg) {
795   HANDLE_TH_ERRORS
796   TORCH_WARN_DEPRECATION(
797       "torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead.")
798   at::ScalarType current_dtype = at::autocast::get_autocast_dtype(at::kCPU);
799   return utils::wrap(current_dtype);
800   END_HANDLE_TH_ERRORS
801 }
802 
get_autocast_ipu_dtype(PyObject * _unused,PyObject * arg)803 static PyObject* get_autocast_ipu_dtype(PyObject* _unused, PyObject* arg) {
804   HANDLE_TH_ERRORS
805   TORCH_WARN_DEPRECATION(
806       "torch.get_autocast_ipu_dtype() is deprecated. Please use torch.get_autocast_dtype('ipu') instead.")
807   at::ScalarType current_dtype = at::autocast::get_autocast_dtype(at::kIPU);
808   return utils::wrap(current_dtype);
809   END_HANDLE_TH_ERRORS
810 }
811 
get_autocast_xla_dtype(PyObject * _unused,PyObject * arg)812 static PyObject* get_autocast_xla_dtype(PyObject* _unused, PyObject* arg) {
813   HANDLE_TH_ERRORS
814   TORCH_WARN_DEPRECATION(
815       "torch.get_autocast_xla_dtype() is deprecated. Please use torch.get_autocast_dtype('xla') instead.")
816   at::ScalarType current_dtype = at::autocast::get_autocast_dtype(at::kXLA);
817   return utils::wrap(current_dtype);
818   END_HANDLE_TH_ERRORS
819 }
820 
clear_autocast_cache(PyObject * _unused,PyObject * arg)821 static PyObject* clear_autocast_cache(PyObject* _unused, PyObject* arg) {
822   HANDLE_TH_ERRORS {
823     pybind11::gil_scoped_release no_gil;
824     at::autocast::clear_cache();
825   }
826   Py_RETURN_NONE;
827   END_HANDLE_TH_ERRORS
828 }
829 
autocast_increment_nesting(PyObject * _unused,PyObject * arg)830 static PyObject* autocast_increment_nesting(PyObject* _unused, PyObject* arg) {
831   HANDLE_TH_ERRORS
832   return THPUtils_packInt64(at::autocast::increment_nesting());
833   END_HANDLE_TH_ERRORS
834 }
835 
autocast_decrement_nesting(PyObject * _unused,PyObject * arg)836 static PyObject* autocast_decrement_nesting(PyObject* _unused, PyObject* arg) {
837   HANDLE_TH_ERRORS
838   return THPUtils_packInt64(at::autocast::decrement_nesting());
839   END_HANDLE_TH_ERRORS
840 }
841 
is_autocast_cache_enabled(PyObject * _unused,PyObject * arg)842 static PyObject* is_autocast_cache_enabled(PyObject* _unused, PyObject* arg) {
843   HANDLE_TH_ERRORS
844   if (at::autocast::is_autocast_cache_enabled()) {
845     Py_RETURN_TRUE;
846   } else {
847     Py_RETURN_FALSE;
848   }
849   END_HANDLE_TH_ERRORS
850 }
851 
set_autocast_cache_enabled(PyObject * _unused,PyObject * arg)852 static PyObject* set_autocast_cache_enabled(PyObject* _unused, PyObject* arg) {
853   HANDLE_TH_ERRORS
854   TORCH_CHECK_TYPE(
855       PyBool_Check(arg),
856       "enabled must be a bool (got ",
857       Py_TYPE(arg)->tp_name,
858       ")");
859   at::autocast::set_autocast_cache_enabled(arg == Py_True);
860   Py_RETURN_NONE;
861   END_HANDLE_TH_ERRORS
862 }
863 
set_grad_enabled(PyObject * _unused,PyObject * args,PyObject * kwargs)864 static PyObject* set_grad_enabled(
865     PyObject* _unused,
866     PyObject* args,
867     PyObject* kwargs) {
868   HANDLE_TH_ERRORS
869   static PythonArgParser parser({
870       "set_grad_enabled(bool enabled)",
871   });
872   ParsedArgs<1> parsed_args;
873   auto r = parser.parse(args, kwargs, parsed_args);
874 
875   if (at::impl::torch_function_mode_enabled()) {
876     auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
877     return handle_torch_function(
878         r, args, kwargs, torch_C_module, "torch._C", "_set_grad_enabled");
879   }
880   auto grad_enabled = r.toBool(0);
881   GradMode::set_enabled(grad_enabled);
882   Py_RETURN_NONE;
883   END_HANDLE_TH_ERRORS
884 }
885 
is_grad_enabled(PyObject * _unused,PyObject * arg)886 static PyObject* is_grad_enabled(PyObject* _unused, PyObject* arg) {
887   HANDLE_TH_ERRORS
888   if (GradMode::is_enabled()) {
889     Py_RETURN_TRUE;
890   } else {
891     Py_RETURN_FALSE;
892   }
893   END_HANDLE_TH_ERRORS
894 }
895 
set_fwd_grad_enabled(PyObject * _unused,PyObject * arg)896 static PyObject* set_fwd_grad_enabled(PyObject* _unused, PyObject* arg) {
897   HANDLE_TH_ERRORS
898   TORCH_CHECK_TYPE(
899       PyBool_Check(arg),
900       "enabled must be a bool (got ",
901       Py_TYPE(arg)->tp_name,
902       ")");
903   c10::AutogradState::get_tls_state().set_fw_grad_mode(arg == Py_True);
904   Py_RETURN_NONE;
905   END_HANDLE_TH_ERRORS
906 }
907 
is_fwd_grad_enabled(PyObject * _unused,PyObject * arg)908 static PyObject* is_fwd_grad_enabled(PyObject* _unused, PyObject* arg) {
909   HANDLE_TH_ERRORS
910   if (c10::AutogradState::get_tls_state().get_fw_grad_mode()) {
911     Py_RETURN_TRUE;
912   } else {
913     Py_RETURN_FALSE;
914   }
915   END_HANDLE_TH_ERRORS
916 }
917 
set_multithreading_enabled(PyObject * self,PyObject * args,PyObject * kwargs)918 static PyObject* set_multithreading_enabled(
919     PyObject* self,
920     PyObject* args,
921     PyObject* kwargs) {
922   HANDLE_TH_ERRORS
923   static PythonArgParser parser({
924       "set_multithreading_enabled(bool enabled)",
925   });
926   ParsedArgs<1> parsed_args;
927   auto r = parser.parse(args, kwargs, parsed_args);
928 
929   if (at::impl::torch_function_mode_enabled()) {
930     auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
931     return handle_torch_function(
932         r,
933         args,
934         kwargs,
935         torch_C_module,
936         "torch._C",
937         "_set_multithreading_enabled");
938   }
939   auto multithreading_enabled = r.toBool(0);
940   c10::AutogradState::get_tls_state().set_multithreading_enabled(
941       multithreading_enabled);
942   Py_RETURN_NONE;
943   END_HANDLE_TH_ERRORS
944 }
945 
is_multithreading_enabled(PyObject * self,PyObject * args)946 static PyObject* is_multithreading_enabled(PyObject* self, PyObject* args) {
947   HANDLE_TH_ERRORS
948   if (c10::AutogradState::get_tls_state().get_multithreading_enabled()) {
949     Py_RETURN_TRUE;
950   } else {
951     Py_RETURN_FALSE;
952   }
953   END_HANDLE_TH_ERRORS
954 }
955 
set_view_replay_enabled(PyObject * self,PyObject * args,PyObject * kwargs)956 static PyObject* set_view_replay_enabled(
957     PyObject* self,
958     PyObject* args,
959     PyObject* kwargs) {
960   HANDLE_TH_ERRORS
961   static PythonArgParser parser({
962       "set_view_replay_enabled(bool enabled)",
963   });
964   ParsedArgs<1> parsed_args;
965   auto r = parser.parse(args, kwargs, parsed_args);
966 
967   if (at::impl::torch_function_mode_enabled()) {
968     auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
969     return handle_torch_function(
970         r,
971         args,
972         kwargs,
973         torch_C_module,
974         "torch._C",
975         "_set_view_replay_enabled");
976   }
977   auto view_replay_enabled = r.toBool(0);
978   c10::AutogradState::get_tls_state().set_view_replay_enabled(
979       view_replay_enabled);
980   Py_RETURN_NONE;
981   END_HANDLE_TH_ERRORS
982 }
983 
is_view_replay_enabled(PyObject * self,PyObject * args)984 static PyObject* is_view_replay_enabled(PyObject* self, PyObject* args) {
985   HANDLE_TH_ERRORS
986   if (c10::AutogradState::get_tls_state().get_view_replay_enabled()) {
987     Py_RETURN_TRUE;
988   } else {
989     Py_RETURN_FALSE;
990   }
991   END_HANDLE_TH_ERRORS
992 }
993 
is_inference_mode_enabled(PyObject * _unused,PyObject * arg)994 static PyObject* is_inference_mode_enabled(PyObject* _unused, PyObject* arg) {
995   HANDLE_TH_ERRORS
996   if (c10::InferenceMode::is_enabled()) {
997     Py_RETURN_TRUE;
998   } else {
999     Py_RETURN_FALSE;
1000   }
1001   END_HANDLE_TH_ERRORS
1002 }
1003 
set_anomaly_mode_enabled(PyObject * _unused,PyObject * args,PyObject * kwargs)1004 static PyObject* set_anomaly_mode_enabled(
1005     PyObject* _unused,
1006     PyObject* args,
1007     PyObject* kwargs) {
1008   HANDLE_TH_ERRORS
1009   static PythonArgParser parser({
1010       "set_anomaly_enabled(bool enabled, bool check_nan=True)",
1011   });
1012   ParsedArgs<2> parsed_args;
1013   auto r = parser.parse(args, kwargs, parsed_args);
1014   AnomalyMode::set_enabled(r.toBool(0), r.toBool(1));
1015   Py_RETURN_NONE;
1016   END_HANDLE_TH_ERRORS
1017 }
1018 
is_anomaly_mode_enabled(PyObject * _unused,PyObject * arg)1019 static PyObject* is_anomaly_mode_enabled(PyObject* _unused, PyObject* arg) {
1020   HANDLE_TH_ERRORS
1021   if (AnomalyMode::is_enabled()) {
1022     Py_RETURN_TRUE;
1023   } else {
1024     Py_RETURN_FALSE;
1025   }
1026   END_HANDLE_TH_ERRORS
1027 }
1028 
is_anomaly_check_nan_enabled(PyObject * _unused,PyObject * arg)1029 static PyObject* is_anomaly_check_nan_enabled(
1030     PyObject* _unused,
1031     PyObject* arg) {
1032   HANDLE_TH_ERRORS
1033   if (AnomalyMode::should_check_nan()) {
1034     Py_RETURN_TRUE;
1035   } else {
1036     Py_RETURN_FALSE;
1037   }
1038   END_HANDLE_TH_ERRORS
1039 }
1040 
python_enter_dual_level(PyObject * _unused,PyObject * arg)1041 static PyObject* python_enter_dual_level(PyObject* _unused, PyObject* arg) {
1042   HANDLE_TH_ERRORS
1043   // It is unlikely that the depth of forward nesting will overflow int64_t so
1044   // we just static cast here.
1045   return utils::wrap(static_cast<int64_t>(forward_ad::enter_dual_level()));
1046   END_HANDLE_TH_ERRORS
1047 }
1048 
python_exit_dual_level(PyObject * _unused,PyObject * args,PyObject * kwargs)1049 static PyObject* python_exit_dual_level(
1050     PyObject* _unused,
1051     PyObject* args,
1052     PyObject* kwargs) {
1053   HANDLE_TH_ERRORS
1054   static PythonArgParser parser({"exit_dual_level(int64_t level)"});
1055 
1056   ParsedArgs<1> parsed_args;
1057   auto _r = parser.parse(args, kwargs, parsed_args);
1058 
1059   auto idx = _r.toInt64(0);
1060   // Make sure the given index is valid before casting it
1061   TORCH_CHECK(idx >= 0, "Dual level must be a positive number.");
1062   forward_ad::exit_dual_level(static_cast<uint64_t>(idx));
1063   Py_RETURN_NONE;
1064   END_HANDLE_TH_ERRORS
1065 }
1066 
is_torch_function_mode_enabled(PyObject * _unused,PyObject * _unused2)1067 static PyObject* is_torch_function_mode_enabled(
1068     PyObject* _unused,
1069     PyObject* _unused2) {
1070   HANDLE_TH_ERRORS
1071   if (at::impl::torch_function_mode_enabled()) {
1072     Py_RETURN_TRUE;
1073   } else {
1074     Py_RETURN_FALSE;
1075   }
1076   END_HANDLE_TH_ERRORS
1077 }
1078 
// torch._C._push_on_torch_function_stack(mode): pushes a torch_function mode
// object onto the thread-local mode stack. Passing None is a no-op.
static PyObject* push_on_torch_function_stack(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  if (arg != Py_None) {
    // The SafePyObject wrapper holds the reference taken here and (per this
    // file's usage pattern) releases it when the entry is destroyed, so the
    // INCREF must precede the push.
    Py_INCREF(arg);
    at::impl::PythonTorchFunctionTLS::push_onto_stack(
        std::make_shared<c10::SafePyObject>(arg, getPyInterpreter()));
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
1091 
// torch._C._pop_torch_function_stack(): removes and returns the top
// torch_function mode from the thread-local stack.
static PyObject* pop_torch_function_stack(
    PyObject* _unused,
    PyObject* _unused2) {
  HANDLE_TH_ERRORS
  const auto& mode = at::impl::PythonTorchFunctionTLS::pop_stack();
  // ptr() yields a borrowed reference owned by `mode`; take our own strong
  // reference before `mode` goes out of scope and returns it to Python.
  auto* r = mode->ptr(getPyInterpreter());
  Py_INCREF(r);
  return r;
  END_HANDLE_TH_ERRORS
}
1102 
// torch._C._get_function_stack_at(level): returns (without removing) the
// torch_function mode at index `level` of the thread-local stack.
static PyObject* get_function_stack_at(
    PyObject* _unused,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({"get_stack_at(int64_t level)"});

  ParsedArgs<1> parsed_args;
  auto _r = parser.parse(args, kwargs, parsed_args);

  auto idx = _r.toInt64(0);
  const auto& mode = at::impl::PythonTorchFunctionTLS::get_stack_at(idx);
  // ptr() yields a borrowed reference; promote it to a strong reference
  // before handing it back to Python.
  auto* r = mode->ptr(getPyInterpreter());
  Py_INCREF(r);
  return r;
  END_HANDLE_TH_ERRORS
}
1120 
len_torch_function_stack(PyObject * _unused,PyObject * _unused2)1121 static PyObject* len_torch_function_stack(
1122     PyObject* _unused,
1123     PyObject* _unused2) {
1124   HANDLE_TH_ERRORS
1125   const auto len = at::impl::PythonTorchFunctionTLS::stack_len();
1126   return utils::wrap(static_cast<int64_t>(len));
1127   END_HANDLE_TH_ERRORS
1128 }
1129 
// torch._C._push_on_torch_dispatch_stack(mode): pushes a torch_dispatch mode
// onto the thread-local dispatch-mode stack. "Infra" modes (those exposing a
// _mode_key attribute) are installed in their dedicated per-key slot; all
// other modes go onto the generic user stack. Passing None is a no-op.
static PyObject* push_on_torch_dispatch_stack(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  if (arg != Py_None) {
    using c10::impl::TorchDispatchModeKey;
    // When we push a mode onto the mode stack, we need to
    // check if it's an "infra" mode, by checking its _mode_key attribute.
    std::optional<c10::impl::TorchDispatchModeKey> mode_key = std::nullopt;
    py::object maybe_mode_key_obj =
        PyObject_FastGetAttrString(arg, "_mode_key");
    if (maybe_mode_key_obj) {
      mode_key = py::cast<c10::impl::TorchDispatchModeKey>(maybe_mode_key_obj);
      c10::impl::TorchDispatchModeTLS::set_mode(
          std::make_shared<c10::impl::PyObject_TorchDispatchMode>(
              arg, getPyInterpreter()),
          mode_key.value());
    } else {
      c10::impl::TorchDispatchModeTLS::push_non_infra_mode_onto_stack(
          std::make_shared<c10::impl::PyObject_TorchDispatchMode>(
              arg, getPyInterpreter()));
    }
    // The reference taken here balances the one the TLS entry releases when
    // it is eventually popped/unset.
    Py_INCREF(arg);
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
1157 
// torch._C._pop_torch_dispatch_stack(mode_key): removes and returns a
// torch_dispatch mode. With a mode_key, unsets that specific infra-mode slot
// (raising if it was not set); with None, pops the top of the generic stack.
static PyObject* pop_torch_dispatch_stack(
    PyObject* _unused,
    PyObject* maybe_mode_key) {
  HANDLE_TH_ERRORS
  std::optional<c10::impl::TorchDispatchModeKey> mode_key = std::nullopt;
  PyObject* r = nullptr;
  if (maybe_mode_key != Py_None) {
    mode_key = py::cast<c10::impl::TorchDispatchModeKey>(maybe_mode_key);
    auto maybe_mode =
        c10::impl::TorchDispatchModeTLS::unset_mode(mode_key.value());
    TORCH_CHECK(
        maybe_mode.has_value(),
        "Attempted to unset ",
        c10::impl::to_string(mode_key.value()),
        ", but there wasn't one active.");
    auto mode = maybe_mode.value();
    r = mode->ptr(getPyInterpreter());
  } else {
    auto mode = c10::impl::TorchDispatchModeTLS::pop_stack();
    r = mode->ptr(getPyInterpreter());
  }
  // ptr() yields a borrowed reference; promote it before returning.
  Py_INCREF(r);
  return r;
  END_HANDLE_TH_ERRORS
}
1183 
// torch._C._get_dispatch_stack_at(level): returns (without removing) the
// torch_dispatch mode at index `level` of the thread-local stack.
static PyObject* get_dispatch_stack_at(
    PyObject* _unused,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({"get_stack_at(int64_t level)"});

  ParsedArgs<1> parsed_args;
  auto _r = parser.parse(args, kwargs, parsed_args);

  auto idx = _r.toInt64(0);
  const auto& mode = c10::impl::TorchDispatchModeTLS::get_stack_at(idx);
  // ptr() yields a borrowed reference; promote it before returning.
  auto* r = mode->ptr(getPyInterpreter());
  Py_INCREF(r);
  return r;
  END_HANDLE_TH_ERRORS
}
1201 
// torch._C._set_dispatch_mode(mode): installs an "infra" torch_dispatch mode
// into the slot named by its _mode_key attribute. The mode must not be None
// and must carry a _mode_key.
static PyObject* set_dispatch_mode(PyObject* _unused, PyObject* mode) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(mode != Py_None);

  py::object maybe_mode_key_obj = PyObject_FastGetAttrString(mode, "_mode_key");
  TORCH_CHECK(
      maybe_mode_key_obj,
      "set_dispatch_mode() called with a mode that does not contain a _mode_key attribute!");
  auto mode_key = py::cast<c10::impl::TorchDispatchModeKey>(maybe_mode_key_obj);

  // The reference taken here is held for the lifetime of the TLS entry.
  Py_INCREF(mode);
  c10::impl::TorchDispatchModeTLS::set_mode(
      std::make_shared<c10::impl::PyObject_TorchDispatchMode>(
          mode, getPyInterpreter()),
      mode_key);

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
1221 
// torch._C._get_dispatch_mode(mode_key): returns the infra mode installed in
// the given slot, or None if that slot is empty. mode_key must not be None.
static PyObject* get_dispatch_mode(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(arg != Py_None);
  auto mode_key = py::cast<c10::impl::TorchDispatchModeKey>(arg);

  auto maybe_mode = c10::impl::TorchDispatchModeTLS::get_mode(mode_key);
  if (maybe_mode == std::nullopt) {
    Py_RETURN_NONE;
  }
  // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
  auto* r = maybe_mode.value()->ptr(getPyInterpreter());
  // ptr() yields a borrowed reference; promote it before returning.
  Py_INCREF(r);
  return r;
  END_HANDLE_TH_ERRORS
}
1237 
// torch._C._unset_dispatch_mode(mode_key): clears the infra-mode slot for the
// given key and returns the mode that was installed there, or None if the
// slot was already empty. mode_key must not be None.
static PyObject* unset_dispatch_mode(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(arg != Py_None);
  auto mode_key = py::cast<c10::impl::TorchDispatchModeKey>(arg);

  const auto maybe_mode = c10::impl::TorchDispatchModeTLS::unset_mode(mode_key);
  if (maybe_mode == std::nullopt) {
    Py_RETURN_NONE;
  }
  // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
  auto* r = maybe_mode.value()->ptr(getPyInterpreter());
  // ptr() yields a borrowed reference; promote it before returning.
  Py_INCREF(r);
  return r;
  END_HANDLE_TH_ERRORS
}
1253 
len_torch_dispatch_stack(PyObject * _unused,PyObject * args)1254 static PyObject* len_torch_dispatch_stack(PyObject* _unused, PyObject* args) {
1255   HANDLE_TH_ERRORS
1256   const auto len = c10::impl::TorchDispatchModeTLS::stack_len();
1257   return utils::wrap(static_cast<int64_t>(len));
1258   END_HANDLE_TH_ERRORS
1259 }
1260 
THPModule_increment_version(PyObject * _unused,PyObject * tensor_list)1261 PyObject* THPModule_increment_version(
1262     PyObject* _unused,
1263     PyObject* tensor_list) {
1264   HANDLE_TH_ERRORS
1265   auto iterator = THPObjectPtr(PyObject_GetIter(tensor_list));
1266   TORCH_CHECK(iterator, "increment_version expect a Iterable[Tensor] as input");
1267   auto item = THPObjectPtr(PyIter_Next(iterator));
1268   while (item) {
1269     TORCH_CHECK(
1270         THPVariable_Check(item),
1271         "increment_version expects each element of the iterable to be a tensor");
1272     auto t = THPVariable_Unpack(item);
1273     if (!t.is_inference()) {
1274       torch::autograd::increment_version(t);
1275     }
1276     item = THPObjectPtr(PyIter_Next(iterator));
1277   }
1278   Py_RETURN_NONE;
1279   END_HANDLE_TH_ERRORS
1280 }
1281 
1282 // autograd methods on torch._C
// autograd methods on torch._C
// Registration table consumed by python_functions() below; names are the
// Python-visible attribute names on torch._C.
static PyMethodDef methods[] = {
    // Grad / forward-grad / inference-mode flags.
    {"_set_grad_enabled",
     castPyCFunctionWithKeywords(set_grad_enabled),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"is_grad_enabled", is_grad_enabled, METH_NOARGS, nullptr},
    {"_set_fwd_grad_enabled", set_fwd_grad_enabled, METH_O, nullptr},
    {"_is_fwd_grad_enabled", is_fwd_grad_enabled, METH_NOARGS, nullptr},
    {"is_inference_mode_enabled",
     is_inference_mode_enabled,
     METH_NOARGS,
     nullptr},
    // Autocast: generic per-device API plus deprecated per-device shims.
    {"set_autocast_enabled",
     castPyCFunctionWithKeywords(set_autocast_enabled),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"is_autocast_enabled",
     castPyCFunctionWithKeywords(is_autocast_enabled),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"set_autocast_dtype",
     castPyCFunctionWithKeywords(set_autocast_dtype),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"get_autocast_dtype",
     castPyCFunctionWithKeywords(get_autocast_dtype),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_is_any_autocast_enabled", is_any_autocast_enabled, METH_NOARGS, nullptr},
    {"_is_autocast_available",
     castPyCFunctionWithKeywords(is_autocast_available),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"clear_autocast_cache", clear_autocast_cache, METH_NOARGS, nullptr},
    {"set_autocast_cpu_enabled", set_autocast_cpu_enabled, METH_O, nullptr},
    {"is_autocast_cpu_enabled", is_autocast_cpu_enabled, METH_NOARGS, nullptr},
    {"set_autocast_cpu_dtype", set_autocast_cpu_dtype, METH_O, nullptr},
    {"get_autocast_cpu_dtype", get_autocast_cpu_dtype, METH_NOARGS, nullptr},
    {"set_autocast_gpu_dtype", set_autocast_gpu_dtype, METH_O, nullptr},
    {"get_autocast_gpu_dtype", get_autocast_gpu_dtype, METH_NOARGS, nullptr},
    {"set_autocast_xla_enabled", set_autocast_xla_enabled, METH_O, nullptr},
    {"is_autocast_xla_enabled", is_autocast_xla_enabled, METH_NOARGS, nullptr},
    {"set_autocast_xla_dtype", set_autocast_xla_dtype, METH_O, nullptr},
    {"get_autocast_xla_dtype", get_autocast_xla_dtype, METH_NOARGS, nullptr},
    {"set_autocast_ipu_enabled", set_autocast_ipu_enabled, METH_O, nullptr},
    {"is_autocast_ipu_enabled", is_autocast_ipu_enabled, METH_NOARGS, nullptr},
    {"set_autocast_ipu_dtype", set_autocast_ipu_dtype, METH_O, nullptr},
    {"get_autocast_ipu_dtype", get_autocast_ipu_dtype, METH_NOARGS, nullptr},
    {"autocast_increment_nesting",
     autocast_increment_nesting,
     METH_NOARGS,
     nullptr},
    {"autocast_decrement_nesting",
     autocast_decrement_nesting,
     METH_NOARGS,
     nullptr},
    {"is_autocast_cache_enabled",
     is_autocast_cache_enabled,
     METH_NOARGS,
     nullptr},
    {"set_autocast_cache_enabled", set_autocast_cache_enabled, METH_O, nullptr},
    // Version counters and anomaly detection.
    {"_increment_version", THPModule_increment_version, METH_O, nullptr},
    {"set_anomaly_enabled",
     castPyCFunctionWithKeywords(set_anomaly_mode_enabled),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"is_anomaly_enabled", is_anomaly_mode_enabled, METH_NOARGS, nullptr},
    {"is_anomaly_check_nan_enabled",
     is_anomaly_check_nan_enabled,
     METH_NOARGS,
     nullptr},
    // Engine multithreading and view replay.
    {"_is_multithreading_enabled",
     is_multithreading_enabled,
     METH_NOARGS,
     nullptr},
    {"_set_multithreading_enabled",
     castPyCFunctionWithKeywords(set_multithreading_enabled),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_is_view_replay_enabled", is_view_replay_enabled, METH_NOARGS, nullptr},
    {"_set_view_replay_enabled",
     castPyCFunctionWithKeywords(set_view_replay_enabled),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    // Forward-AD dual levels.
    {"_enter_dual_level", python_enter_dual_level, METH_NOARGS, nullptr},
    {"_exit_dual_level",
     castPyCFunctionWithKeywords(python_exit_dual_level),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    // torch_function mode stack.
    {"_is_torch_function_mode_enabled",
     is_torch_function_mode_enabled,
     METH_NOARGS,
     nullptr},
    {"_push_on_torch_function_stack",
     push_on_torch_function_stack,
     METH_O,
     nullptr},
    {"_pop_torch_function_stack",
     pop_torch_function_stack,
     METH_NOARGS,
     nullptr},
    {"_get_function_stack_at",
     castPyCFunctionWithKeywords(get_function_stack_at),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_len_torch_function_stack",
     len_torch_function_stack,
     METH_NOARGS,
     nullptr},
    // torch_dispatch mode stack.
    {"_push_on_torch_dispatch_stack",
     push_on_torch_dispatch_stack,
     METH_O,
     nullptr},
    {"_pop_torch_dispatch_stack", pop_torch_dispatch_stack, METH_O, nullptr},
    {"_get_dispatch_stack_at",
     castPyCFunctionWithKeywords(get_dispatch_stack_at),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_len_torch_dispatch_stack",
     len_torch_dispatch_stack,
     METH_NOARGS,
     nullptr},
    {"_set_dispatch_mode", set_dispatch_mode, METH_O, nullptr},
    {"_get_dispatch_mode", get_dispatch_mode, METH_O, nullptr},
    {"_unset_dispatch_mode", unset_dispatch_mode, METH_O, nullptr},

    // Sentinel terminating the table.
    {nullptr, nullptr, 0, nullptr}};
1410 
// Exposes the method table above so module initialization can register these
// functions on torch._C.
PyMethodDef* python_functions() {
  return methods;
}
1414 
1415 } // namespace torch::autograd
1416