#include <torch/csrc/python_headers.h>

#include <ATen/PythonTorchFunctionTLS.h>
#include <ATen/SavedTensorHooks.h>
#include <ATen/SequenceNumber.h>
#include <ATen/autocast_mode.h>
#include <ATen/core/PythonFallbackKernel.h>
#include <ATen/record_function.h>
#include <c10/core/DeviceType.h>
#include <c10/core/InferenceMode.h>
#include <c10/core/ScalarType.h>
#include <c10/core/impl/PythonDispatcherTLS.h>
#include <torch/csrc/Exceptions.h>
#include <torch/csrc/autograd/VariableTypeUtils.h>
#include <torch/csrc/autograd/autograd.h>
#include <torch/csrc/autograd/autograd_not_implemented_fallback.h>
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/input_metadata.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/autograd/profiler_python.h>
#include <torch/csrc/autograd/python_function.h>
#include <torch/csrc/autograd/python_saved_variable_hooks.h>
#include <torch/csrc/autograd/python_variable.h>
#include <torch/csrc/autograd/record_function_ops.h>
#include <torch/csrc/autograd/saved_variable.h>
#include <torch/csrc/autograd/utils/python_arg_parsing.h>
#include <torch/csrc/autograd/utils/wrap_outputs.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/profiler/collection.h>
#include <torch/csrc/profiler/kineto_shim.h>
#include <torch/csrc/utils.h>
#include <torch/csrc/utils/disable_torch_function.h>
#include <torch/csrc/utils/pybind.h>
#include <torch/csrc/utils/pycfunction_helpers.h>
#include <torch/csrc/utils/python_raii.h>
#include <torch/csrc/utils/python_torch_function_mode.h>

#include <set>
#include <unordered_set>
#include <utility>

using torch::impl::py_context_manager;
using torch::impl::py_context_manager_DEPRECATED;
namespace {

struct DisableFuncTorch {
  DisableFuncTorch()
      : front_guard_(c10::DispatchKey::FuncTorchDynamicLayerFrontMode),
        back_guard_(c10::DispatchKey::FuncTorchDynamicLayerBackMode) {}
  c10::impl::ExcludeDispatchKeyGuard front_guard_;
  c10::impl::ExcludeDispatchKeyGuard back_guard_;
};

struct DisableAutocast {
  c10::impl::ExcludeDispatchKeyGuard guard_{c10::autocast_dispatch_keyset};
};

struct EnableTorchFunction {
  EnableTorchFunction()
      : old_(at::impl::PythonTorchFunctionTLS::get_disabled_state()) {
    at::impl::PythonTorchFunctionTLS::set_disabled_state(
        at::impl::TorchFunctionDisabledState::ENABLED);
  }
  ~EnableTorchFunction() {
    at::impl::PythonTorchFunctionTLS::set_disabled_state(old_);
  }
  at::impl::TorchFunctionDisabledState old_;
};

struct EnablePythonDispatcher {
  EnablePythonDispatcher() : old_(c10::impl::PythonDispatcherTLS::get_state()) {
    c10::impl::PythonDispatcherTLS::set_state(getPyInterpreter());
  }
  ~EnablePythonDispatcher() {
    c10::impl::PythonDispatcherTLS::set_state(old_);
  }
  c10::impl::PyInterpreter* old_;
};

struct EnablePreDispatch {
  EnablePreDispatch() : guard_(c10::DispatchKey::PreDispatch) {}
  c10::impl::IncludeDispatchKeyGuard guard_;
};

} // namespace

PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject* unused) {
  using namespace torch::autograd::profiler;
  using namespace torch::profiler::impl;
  auto tensor_module = THPObjectPtr(PyImport_ImportModule("torch._tensor"));
  if (!tensor_module)
    return nullptr;

  // NOTE: "leaks" THPVariableClass
  THPVariableClass = PyObject_GetAttrString(tensor_module, "Tensor");
  if (!THPVariableClass)
    return nullptr;

  auto autograd_module = THPObjectPtr(PyImport_ImportModule("torch.autograd"));
  if (!autograd_module)
    return nullptr;

  // NOTE: "leaks" Function
  THPFunctionClass = PyObject_GetAttrString(autograd_module, "Function");
  if (!THPFunctionClass)
    return nullptr;

  // NOTE: "leaks" GradientEdge
  auto autograd_graph_mod =
      THPObjectPtr(PyImport_ImportModule("torch.autograd.graph"));
  THPGradientEdgeClass =
      PyObject_GetAttrString(autograd_graph_mod, "GradientEdge");
  if (!THPGradientEdgeClass)
    return nullptr;

  auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
  if (!torch_C_module)
    return nullptr;
  auto _C_m = py::handle(torch_C_module).cast<py::module>();
  auto m = _C_m.def_submodule("_autograd", "autograd bindings");

  auto parameter_module =
      THPObjectPtr(PyImport_ImportModule("torch.nn.parameter"));
  if (!parameter_module)
    return nullptr;

  // NOTE: "leaks" ParameterClass
  ParameterClass = PyObject_GetAttrString(parameter_module, "Parameter");
  if (!ParameterClass)
    return nullptr;

  py::class_<LegacyEvent>(m, "ProfilerEvent")
      .def("kind", &LegacyEvent::kindStr)
      .def("name", [](const LegacyEvent& e) { return e.name(); })
      .def("thread_id", &LegacyEvent::threadId)
      .def("fwd_thread_id", &LegacyEvent::fwdThreadId)
      .def("device", &LegacyEvent::device)
      .def("cpu_elapsed_us", &LegacyEvent::cpuElapsedUs)
      .def("cuda_elapsed_us", &LegacyEvent::cudaElapsedUs)
      .def("has_cuda", &LegacyEvent::hasCuda)
      .def("shapes", &LegacyEvent::shapes)
      .def("cpu_memory_usage", &LegacyEvent::cpuMemoryUsage)
      .def("cuda_memory_usage", &LegacyEvent::cudaMemoryUsage)
      .def("handle", &LegacyEvent::handle)
      .def("node_id", &LegacyEvent::nodeId)
      .def("is_remote", &LegacyEvent::isRemote)
      .def("sequence_nr", &LegacyEvent::sequenceNr)
      .def("stack", &LegacyEvent::stack)
      .def("scope", &LegacyEvent::scope)
      .def("correlation_id", &LegacyEvent::correlationId)
      .def("start_us", &LegacyEvent::cpuUs)
      .def("flops", &LegacyEvent::flops)
      .def("is_async", &LegacyEvent::isAsync);

  py::enum_<c10::DeviceType>(m, "DeviceType")
      .value("CPU", c10::DeviceType::CPU)
      .value("CUDA", c10::DeviceType::CUDA)
      .value("MKLDNN", c10::DeviceType::MKLDNN)
      .value("OPENGL", c10::DeviceType::OPENGL)
      .value("OPENCL", c10::DeviceType::OPENCL)
      .value("IDEEP", c10::DeviceType::IDEEP)
      .value("HIP", c10::DeviceType::HIP)
      .value("FPGA", c10::DeviceType::FPGA)
      .value("MAIA", c10::DeviceType::MAIA)
      .value("XLA", c10::DeviceType::XLA)
      .value("Vulkan", c10::DeviceType::Vulkan)
      .value("Metal", c10::DeviceType::Metal)
      .value("XPU", c10::DeviceType::XPU)
      .value("MPS", c10::DeviceType::MPS)
      .value("MTIA", c10::DeviceType::MTIA)
      .value("Meta", c10::DeviceType::Meta)
      .value("HPU", c10::DeviceType::HPU)
      .value("VE", c10::DeviceType::VE)
      .value("Lazy", c10::DeviceType::Lazy)
      .value("IPU", c10::DeviceType::IPU)
      .value("PrivateUse1", c10::DeviceType::PrivateUse1);

  using torch::autograd::CreationMeta;
  py::enum_<CreationMeta>(m, "CreationMeta")
      .value("DEFAULT", CreationMeta::DEFAULT)
      .value("IN_CUSTOM_FUNCTION", CreationMeta::IN_CUSTOM_FUNCTION)
      .value("MULTI_OUTPUT_NODE", CreationMeta::MULTI_OUTPUT_NODE)
      .value("NO_GRAD_MODE", CreationMeta::NO_GRAD_MODE)
      .value("INFERENCE_MODE", CreationMeta::INFERENCE_MODE);

  py::class_<torch::autograd::InputMetadata>(m, "_InputMetadata")
      .def_property_readonly(
          "dtype",
          [](const torch::autograd::InputMetadata& m) {
            PyObject* raw_obj =
                (PyObject*)torch::getTHPDtype(m.dtype().toScalarType());
            return py::reinterpret_borrow<py::object>(raw_obj);
          })
      .def_property_readonly("device", &torch::autograd::InputMetadata::device)
      .def_property_readonly(
          "shape", &torch::autograd::InputMetadata::shape_as_dim_vector)
      .def_property_readonly(
          "is_nested_tensor", &torch::autograd::InputMetadata::is_nested_tensor)
      .def_property_readonly(
          "is_cpp_nested_tensor",
          &torch::autograd::InputMetadata::is_cpp_nested_tensor);

  py::class_<KinetoEvent>(m, "_KinetoEvent")
      // name of the event
      .def("name", [](const KinetoEvent& e) { return e.name(); })
      // PyTorch thread id of the start callback
      .def(
          "start_thread_id",
          [](const KinetoEvent& e) { return e.startThreadId(); })
      // PyTorch thread id of the end callback
      .def(
          "end_thread_id", [](const KinetoEvent& e) { return e.endThreadId(); })
      // for events of scope BACKWARD_FUNCTION - PyTorch thread id
      // of the corresponding forward op
      .def(
          "fwd_thread_id", [](const KinetoEvent& e) { return e.fwdThreadId(); })
      // together with fwd_thread_id, used to uniquely identify
      // the forward op
      .def("sequence_nr", [](const KinetoEvent& e) { return e.sequenceNr(); })
      // absolute start time (since unix epoch) in ns
      .def("start_ns", [](const KinetoEvent& e) { return e.startNs(); })
      // absolute end time (since unix epoch) in ns
      .def("end_ns", [](const KinetoEvent& e) { return e.endNs(); })
      // duration in ns
      .def("duration_ns", [](const KinetoEvent& e) { return e.durationNs(); })
      // used for correlation between high-level PyTorch events
      // and low-level device events
      .def(
          "correlation_id",
          [](const KinetoEvent& e) { return e.correlationId(); })
      // shapes of input tensors
      .def("shapes", [](const KinetoEvent& e) { return e.shapes().vec(); })
      .def("dtypes", [](const KinetoEvent& e) { return e.dtypes().vec(); })
      .def(
          "concrete_inputs",
          [](const KinetoEvent& e) {
            std::vector<py::object> as_pyobj;
            std::transform(
                e.concreteInputs().begin(),
                e.concreteInputs().end(),
                std::back_inserter(as_pyobj),
                [](const c10::IValue& val) {
                  return torch::jit::toPyObject(val);
                });
            return as_pyobj;
          })
      .def(
          "kwinputs",
          [](const KinetoEvent& e) {
            std::unordered_map<std::string, py::object> inputs;
            for (const auto& [key, value] : e.kwinputs()) {
              inputs[key] = torch::jit::toPyObject(value);
            }
            return inputs;
          })
      // stack traces of the PyTorch CPU events
      .def("stack", [](const KinetoEvent& e) { return e.stack().vec(); })
      // type of the RecordFunction that generated a PyTorch CPU event
      // (op, torchscript function, user label, etc)
      .def("scope", [](const KinetoEvent& e) { return e.scope(); })
      // device number, for CPU - process id
      .def("device_index", [](const KinetoEvent& e) { return e.deviceIndex(); })
      // for CUDA - stream id, for CPU - start thread id
      .def(
          "device_resource_id",
          [](const KinetoEvent& e) { return e.deviceResourceId(); })
      // device type
      .def("device_type", [](const KinetoEvent& e) { return e.deviceType(); })
      // correlation id of a linked event
      .def(
          "linked_correlation_id",
          [](const KinetoEvent& e) { return e.linkedCorrelationId(); })
      // compute flops
      .def("flops", [](const KinetoEvent& e) { return e.flops(); })
      // whether this is an async event
278 .def("is_async", [](const KinetoEvent& e) { return e.isAsync(); })
279 .def("cuda_elapsed_us", &KinetoEvent::cudaElapsedUs)
280 .def("privateuse1_elapsed_us", &KinetoEvent::privateuse1ElapsedUs)
281 .def(
282 "is_user_annotation",
283 [](const KinetoEvent& e) {
284 return e.activityType() ==
285 (uint8_t)libkineto::ActivityType::USER_ANNOTATION ||
286 e.activityType() ==
287 (uint8_t)libkineto::ActivityType::GPU_USER_ANNOTATION;
288 })
289 .def("nbytes", [](const KinetoEvent& e) { return e.nBytes(); });
290
291 m.def("_soft_assert_raises", &setSoftAssertRaises);
292 m.def("_get_sequence_nr", &at::sequence_number::peek);
293
294 py::class_<ProfilerResult>(m, "_ProfilerResult")
295 .def("trace_start_ns", &ProfilerResult::trace_start_ns)
296 .def("events", &ProfilerResult::events)
297 .def("experimental_event_tree", &ProfilerResult::event_tree)
298 #ifdef USE_KINETO
299 .def("save", &ProfilerResult::save)
300 #endif // USE_KINETO
301 ;
302
303 m.def(
304 "_enable_profiler",
305 &enableProfiler,
306 py::arg("config"),
307 py::arg("activities"),
308 py::arg("scopes") = std::unordered_set<at::RecordScope>());
309 m.def("_disable_profiler", disableProfiler);
310 m.def(
311 "_prepare_profiler",
312 prepareProfiler,
313 py::call_guard<py::gil_scoped_release>());
314 m.def(
315 "_toggle_collection_dynamic",
316 toggleCollectionDynamic,
317 py::call_guard<py::gil_scoped_release>());
318 m.def("_add_metadata_json", addMetadataJson); // Only if `USE_KINETO` is set
319 m.def("_kineto_step", profilerStep); // Only if `USE_KINETO` is set
320 m.def("kineto_available", []() { return torch::profiler::kKinetoAvailable; });
321
322 // NOTICE: These record functions are not torch operators and may not show up
323 // in TorchScript tracing, FX transforms, or operator serialization. For these
324 // use cases, please use `torch.profiler.record_function`.
325 // Creates a new profiling scope using RecordFunction and invokes its starting
326 // callbacks.
327 m.def(
328 "_record_function_with_args_enter",
329 [](const std::string& name, const py::args& args) {
330 using torch::autograd::profiler::PythonRecordFunction;
331 auto python_rec = c10::make_intrusive<PythonRecordFunction>(
332 at::RecordScope::USER_SCOPE);
333 auto* rec = &python_rec->record;
334 if (rec->isActive()) {
335 if (rec->needsInputs()) {
336 auto iv_inputs = std::vector<c10::IValue>();
337 for (const auto& arg : args) {
338 iv_inputs.push_back(torch::jit::toTypeInferredIValue(arg));
339 }
340 rec->before(
341 name,
342 c10::ArrayRef<const c10::IValue>(
343 iv_inputs.data(), iv_inputs.size()));
344 } else {
345 rec->before(name);
346 }
347 }
348 return torch::jit::toPyObject(std::move(python_rec));
349 });
350
  // Ends the profiling scope created with _record_function_with_args_enter.
  m.def("_record_function_with_args_exit", [](const py::object& obj) {
    using torch::autograd::profiler::PythonRecordFunction;
    auto python_record = torch::jit::toCustomClass<PythonRecordFunction>(obj);

    // We don't need to do anything with the record itself; we only had to
    // keep it alive until this point.
    python_record->record.end();
  });
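
  // Illustrative Python usage of the enter/exit pair above (a sketch only;
  // the supported public API is torch.profiler.record_function):
  //   rec = torch._C._autograd._record_function_with_args_enter("my_scope", x)
  //   ...  # code to profile runs inside the RecordFunction scope
  //   torch._C._autograd._record_function_with_args_exit(rec)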

  m.def("_supported_activities", []() {
    std::set<torch::profiler::impl::ActivityType> activities{
        torch::profiler::impl::ActivityType::CPU};
#if defined(USE_KINETO) && \
    (!defined(LIBKINETO_NOCUPTI) || !defined(LIBKINETO_NOROCTRACER))
    if (at::hasMTIA()) {
      activities.insert(torch::profiler::impl::ActivityType::MTIA);
    }
    if (at::getNumGPUs() > 0) {
      activities.insert(torch::profiler::impl::ActivityType::CUDA);
    }
#elif defined(USE_KINETO)
    if (at::hasXPU()) {
      activities.insert(torch::profiler::impl::ActivityType::XPU);
    }
    if (at::hasMTIA()) {
      activities.insert(torch::profiler::impl::ActivityType::MTIA);
    }
    if (c10::get_privateuse1_backend() != "privateuseone") {
      activities.insert(torch::profiler::impl::ActivityType::PrivateUse1);
    }
#endif
    return activities;
  });

  m.def("_unsafe_set_version_counter", [](const at::Tensor& t, int64_t i) {
    auto vc = torch::autograd::impl::version_counter(t);
    vc.set_version(i);
  });

  m.def("_enable_profiler_legacy", enableProfilerLegacy);
  py::class_<ProfilerDisableOptions>(m, "_ProfilerDisableOptions")
      .def(py::init<bool, bool>());
  m.def(
      "_disable_profiler_legacy",
      disableProfilerLegacy,
      py::arg("profiler_disable_options") = ProfilerDisableOptions());
  m.def("_profiler_enabled", profilerEnabled);
  m.def("_profiler_type", torch::profiler::impl::profilerType);
  m.def("_enable_record_function", [](bool enable) {
    at::enableRecordFunction(enable);
  });
  m.def("_set_empty_test_observer", [](bool is_global, double sampling_prob) {
    auto cb =
        at::RecordFunctionCallback(nullptr).needsInputs(true).samplingProb(
            sampling_prob);
    if (is_global) {
      at::addGlobalCallback(cb);
    } else {
      at::addThreadLocalCallback(cb);
    }
  });
  m.def("_clear_callbacks", []() { at::clearCallbacks(); });
  m.def(
      "_saved_tensors_hooks_is_enabled",
      at::SavedTensorDefaultHooks::is_enabled);
  m.def("_saved_tensors_hooks_enable", at::SavedTensorDefaultHooks::enable);
  m.def("_saved_tensors_hooks_disable", at::SavedTensorDefaultHooks::disable);
  m.def(
      "_saved_tensors_hooks_set_tracing",
      at::SavedTensorDefaultHooks::set_tracing);
  m.def(
      "_saved_tensors_hooks_get_disabled_error_message",
      at::SavedTensorDefaultHooks::get_disabled_error_message);
  m.def(
      "_push_saved_tensors_default_hooks",
      [](py::function& pack_hook, py::function& unpack_hook) {
        torch::autograd::PyDefaultSavedVariableHooks::push_hooks(
            pack_hook, unpack_hook);
      });
  m.def("_pop_saved_tensors_default_hooks", []() {
    torch::autograd::PyDefaultSavedVariableHooks::pop_hooks();
  });
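
  // A hedged sketch of how the two bindings above are typically driven from
  // Python (they back torch.autograd.graph.saved_tensors_hooks): pack_hook
  // sees each tensor as it is saved for backward; unpack_hook must invert it.
  //   torch._C._autograd._push_saved_tensors_default_hooks(
  //       lambda t: t.cpu(), lambda t: t.cuda())
  //   ...  # tensors saved for backward now pass through the hooks
  //   torch._C._autograd._pop_saved_tensors_default_hooks()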

  m.def("_get_creation_meta", [](const at::Tensor& t) {
    auto* meta = torch::autograd::impl::get_view_autograd_meta(t);
    TORCH_CHECK(meta != nullptr);
    return meta->get_creation_meta();
  });

  m.def(
      "_set_creation_meta",
      [](const at::Tensor& t, CreationMeta new_creation_meta) {
        auto* meta = torch::autograd::impl::get_view_autograd_meta(t);
        TORCH_CHECK(meta != nullptr);
        meta->set_creation_meta(new_creation_meta);
      });

  m.def("_get_current_graph_task_keep_graph", []() {
    return torch::autograd::get_current_graph_task_keep_graph();
  });

  m.def(
      "_get_data_attr", [](const at::Tensor& t) { return t.variable_data(); });

  _C_m.def(
      "_register_py_class_for_device",
      [](const std::string& device, py::object python_type_class) {
        auto cls = python_type_class.ptr();
        registerPythonTensorClass(device, cls);
      });
  _C_m.def("_set_autograd_fallback_mode", [](const std::string& mode) {
    if (mode == "nothing") {
      torch::autograd::setAutogradFallbackMode(
          torch::autograd::AutogradFallbackMode::Nothing);
      return;
    }
    if (mode == "warn") {
      torch::autograd::setAutogradFallbackMode(
          torch::autograd::AutogradFallbackMode::Warn);
      return;
    }
    if (mode == "error") {
      torch::autograd::setAutogradFallbackMode(
          torch::autograd::AutogradFallbackMode::Error);
      return;
    }
    TORCH_INTERNAL_ASSERT(false, "Unsupported AutogradFallbackMode: ", mode);
  });
  _C_m.def("_get_autograd_fallback_mode", []() {
    auto mode = torch::autograd::getAutogradFallbackMode();
    switch (mode) {
      case torch::autograd::AutogradFallbackMode::Nothing:
        return "nothing";
      case torch::autograd::AutogradFallbackMode::Warn:
        return "warn";
      case torch::autograd::AutogradFallbackMode::Error:
        return "error";
      default:
        TORCH_INTERNAL_ASSERT(false, "Unsupported AutogradFallbackMode");
    }
  });
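
  // Illustrative round trip through the two bindings above (internal API;
  // a sketch only):
  //   torch._C._set_autograd_fallback_mode("warn")
  //   assert torch._C._get_autograd_fallback_mode() == "warn"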

  _C_m.def("_activate_gpu_trace", []() { activateGPUTrace(); });

  py_context_manager_DEPRECATED<c10::InferenceMode, bool>(
      _C_m, "_InferenceMode");
  py_context_manager<at::impl::RestorePythonTLSSnapshot>(
      _C_m, "_RestorePythonTLSSnapshot");

  py_context_manager_DEPRECATED<torch::DisableTorchDispatch>(
      _C_m, "_DisableTorchDispatch");
  py_context_manager_DEPRECATED<EnableTorchFunction>(
      _C_m, "_EnableTorchFunction");
  py_context_manager_DEPRECATED<EnablePythonDispatcher>(
      _C_m, "_EnablePythonDispatcher");
  py_context_manager<c10::impl::DisablePythonDispatcher>(
      _C_m, "_DisablePythonDispatcher");
  py_context_manager<EnablePreDispatch>(_C_m, "_EnablePreDispatch");
  py_context_manager_DEPRECATED<DisableFuncTorch>(_C_m, "_DisableFuncTorch");
  py_context_manager<DisableAutocast>(_C_m, "_DisableAutocast");
512 py::class_<torch::autograd::SavedVariable>(std::move(m), "SavedTensor")
513 .def(py::init([]() -> torch::autograd::SavedVariable {
514 TORCH_CHECK(
515 false,
516 "Trying to create a SavedTensor object from Python is forbidden.");
517 }))
518 .def(
519 "register_hooks",
520 [](torch::autograd::SavedVariable& s,
521 py::function& pack_hook,
522 py::function& unpack_hook) {
523 // Because we use a py::object, pybind will increment the refcount
524 // of the hook functions for us
525 s.register_hooks(
526 std::make_unique<torch::autograd::PySavedVariableHooks>(
527 pack_hook, unpack_hook));
528 });
529
530 torch::autograd::profiler::python_tracer::init();
531 Py_RETURN_TRUE;
532 }
533
534 namespace torch::autograd {
535
static PyObject* set_autocast_enabled(
    PyObject* _unused,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser(
      {"set_autocast_enabled(c10::string_view device_type, bool enabled)",
       "set_autocast_enabled(bool enabled)"}); // this signature is deprecated.
  ParsedArgs<2> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  // Set at::kCUDA as the default value to avoid BC-breaking changes.
  at::DeviceType device_type = at::kCUDA;
  int enabled_id = 0;
  if (r.idx == 0) {
    device_type = at::Device(r.string(0)).type();
    enabled_id = 1;
  }
  auto enabled = r.toBool(enabled_id);
  at::autocast::set_autocast_enabled(device_type, enabled);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
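
// Illustrative call forms (Python) matching the two signatures parsed above:
//   torch.set_autocast_enabled("cuda", True)  # preferred, device-qualified
//   torch.set_autocast_enabled(True)          # deprecated; defaults to CUDA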

static PyObject* is_autocast_enabled(
    PyObject* _unused,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser(
      {"is_autocast_enabled(c10::string_view device_type)",
       "is_autocast_enabled()"}); // this signature is deprecated.
  ParsedArgs<1> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  // Set at::kCUDA as the default value to avoid BC-breaking changes.
  at::DeviceType device_type = at::kCUDA;
  if (r.idx == 0) {
    device_type = at::Device(r.string(0)).type();
  }
  if (at::autocast::is_autocast_enabled(device_type)) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

static PyObject* get_autocast_dtype(
    PyObject* _unused,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser(
      {"get_autocast_dtype(c10::string_view device_type)"});
  ParsedArgs<1> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  auto device_type = at::Device(r.string(0)).type();
  at::ScalarType current_dtype = at::autocast::get_autocast_dtype(device_type);
  return utils::wrap(current_dtype);
  END_HANDLE_TH_ERRORS
}

static PyObject* set_autocast_dtype(
    PyObject* _unused,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser(
      {"set_autocast_dtype(c10::string_view device_type, ScalarType dtype)"});
  ParsedArgs<2> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  auto device_type = at::Device(r.string(0)).type();
  auto dtype = r.scalartype(1);
  at::autocast::set_autocast_dtype(device_type, dtype);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* is_any_autocast_enabled(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  if (at::autocast::is_autocast_enabled(at::kCPU) ||
      at::autocast::is_autocast_enabled(at::kCUDA) ||
      at::autocast::is_autocast_enabled(at::kXPU) ||
      at::autocast::is_autocast_enabled(at::kIPU) ||
      at::autocast::is_autocast_enabled(at::kXLA) ||
      at::autocast::is_autocast_enabled(at::kHPU) ||
      at::autocast::is_autocast_enabled(at::kPrivateUse1)) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

static PyObject* is_autocast_available(
    PyObject* _unused,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser(
      {"_is_autocast_available(c10::string_view device_type)"});
  ParsedArgs<1> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  auto device_type = at::Device(r.string(0)).type();
  if (at::autocast::is_autocast_available(device_type)) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

static PyObject* set_autocast_cpu_enabled(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK_TYPE(
      PyBool_Check(arg),
      "enabled must be a bool (got ",
      Py_TYPE(arg)->tp_name,
      ")");
  TORCH_WARN_DEPRECATION(
      "torch.set_autocast_cpu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('cpu', enabled) instead.")
  at::autocast::set_autocast_enabled(at::kCPU, arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* is_autocast_cpu_enabled(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_WARN_DEPRECATION(
      "torch.is_autocast_cpu_enabled() is deprecated. Please use torch.is_autocast_enabled('cpu') instead.")
  if (at::autocast::is_autocast_enabled(at::kCPU)) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

static PyObject* set_autocast_ipu_enabled(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK_TYPE(
      PyBool_Check(arg),
      "enabled must be a bool (got ",
      Py_TYPE(arg)->tp_name,
      ")");
  TORCH_WARN_DEPRECATION(
      "torch.set_autocast_ipu_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('ipu', enabled) instead.")
  at::autocast::set_autocast_enabled(at::kIPU, arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* is_autocast_ipu_enabled(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_WARN_DEPRECATION(
      "torch.is_autocast_ipu_enabled() is deprecated. Please use torch.is_autocast_enabled('ipu') instead.")
  if (at::autocast::is_autocast_enabled(at::kIPU)) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

static PyObject* set_autocast_xla_enabled(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK_TYPE(
      PyBool_Check(arg),
      "enabled must be a bool (got ",
      Py_TYPE(arg)->tp_name,
      ")");
  TORCH_WARN_DEPRECATION(
      "torch.set_autocast_xla_enabled(enabled) is deprecated. Please use torch.set_autocast_enabled('xla', enabled) instead.")
  at::autocast::set_autocast_enabled(at::kXLA, arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* is_autocast_xla_enabled(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_WARN_DEPRECATION(
      "torch.is_autocast_xla_enabled() is deprecated. Please use torch.is_autocast_enabled('xla') instead.")
  if (at::autocast::is_autocast_enabled(at::kXLA)) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

static PyObject* set_autocast_gpu_dtype(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK_TYPE(
      THPDtype_Check(arg),
      "dtype must be a torch.dtype (got ",
      Py_TYPE(arg)->tp_name,
      ")");
  TORCH_WARN_DEPRECATION(
      "torch.set_autocast_gpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cuda', dtype) instead.")
  at::ScalarType targetType = reinterpret_cast<THPDtype*>(arg)->scalar_type;
  at::autocast::set_autocast_dtype(at::kCUDA, targetType);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* set_autocast_cpu_dtype(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK_TYPE(
      THPDtype_Check(arg),
      "dtype must be a torch.dtype (got ",
      Py_TYPE(arg)->tp_name,
      ")");
  TORCH_WARN_DEPRECATION(
      "torch.set_autocast_cpu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('cpu', dtype) instead.")
  at::ScalarType targetType = reinterpret_cast<THPDtype*>(arg)->scalar_type;
  at::autocast::set_autocast_dtype(at::kCPU, targetType);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* set_autocast_ipu_dtype(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK_TYPE(
      THPDtype_Check(arg),
      "dtype must be a torch.dtype (got ",
      Py_TYPE(arg)->tp_name,
      ")");
  TORCH_WARN_DEPRECATION(
      "torch.set_autocast_ipu_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('ipu', dtype) instead.")
  at::ScalarType targetType = reinterpret_cast<THPDtype*>(arg)->scalar_type;
  at::autocast::set_autocast_dtype(at::kIPU, targetType);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* set_autocast_xla_dtype(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK_TYPE(
      THPDtype_Check(arg),
      "dtype must be a torch.dtype (got ",
      Py_TYPE(arg)->tp_name,
      ")");
  TORCH_WARN_DEPRECATION(
      "torch.set_autocast_xla_dtype(dtype) is deprecated. Please use torch.set_autocast_dtype('xla', dtype) instead.")
  at::ScalarType targetType = reinterpret_cast<THPDtype*>(arg)->scalar_type;
  at::autocast::set_autocast_dtype(at::kXLA, targetType);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* get_autocast_gpu_dtype(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_WARN_DEPRECATION(
      "torch.get_autocast_gpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cuda') instead.")
  at::ScalarType current_dtype = at::autocast::get_autocast_dtype(at::kCUDA);
  return utils::wrap(current_dtype);
  END_HANDLE_TH_ERRORS
}

static PyObject* get_autocast_cpu_dtype(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_WARN_DEPRECATION(
      "torch.get_autocast_cpu_dtype() is deprecated. Please use torch.get_autocast_dtype('cpu') instead.")
  at::ScalarType current_dtype = at::autocast::get_autocast_dtype(at::kCPU);
  return utils::wrap(current_dtype);
  END_HANDLE_TH_ERRORS
}

static PyObject* get_autocast_ipu_dtype(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_WARN_DEPRECATION(
      "torch.get_autocast_ipu_dtype() is deprecated. Please use torch.get_autocast_dtype('ipu') instead.")
  at::ScalarType current_dtype = at::autocast::get_autocast_dtype(at::kIPU);
  return utils::wrap(current_dtype);
  END_HANDLE_TH_ERRORS
}

static PyObject* get_autocast_xla_dtype(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_WARN_DEPRECATION(
      "torch.get_autocast_xla_dtype() is deprecated. Please use torch.get_autocast_dtype('xla') instead.")
  at::ScalarType current_dtype = at::autocast::get_autocast_dtype(at::kXLA);
  return utils::wrap(current_dtype);
  END_HANDLE_TH_ERRORS
}

static PyObject* clear_autocast_cache(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS {
    pybind11::gil_scoped_release no_gil;
    at::autocast::clear_cache();
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* autocast_increment_nesting(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  return THPUtils_packInt64(at::autocast::increment_nesting());
  END_HANDLE_TH_ERRORS
}

static PyObject* autocast_decrement_nesting(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  return THPUtils_packInt64(at::autocast::decrement_nesting());
  END_HANDLE_TH_ERRORS
}

static PyObject* is_autocast_cache_enabled(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  if (at::autocast::is_autocast_cache_enabled()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

static PyObject* set_autocast_cache_enabled(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK_TYPE(
      PyBool_Check(arg),
      "enabled must be a bool (got ",
      Py_TYPE(arg)->tp_name,
      ")");
  at::autocast::set_autocast_cache_enabled(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* set_grad_enabled(
    PyObject* _unused,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
      "set_grad_enabled(bool enabled)",
  });
  ParsedArgs<1> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  if (at::impl::torch_function_mode_enabled()) {
    auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
    return handle_torch_function(
        r, args, kwargs, torch_C_module, "torch._C", "_set_grad_enabled");
  }
  auto grad_enabled = r.toBool(0);
  GradMode::set_enabled(grad_enabled);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* is_grad_enabled(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  if (GradMode::is_enabled()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

static PyObject* set_fwd_grad_enabled(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK_TYPE(
      PyBool_Check(arg),
      "enabled must be a bool (got ",
      Py_TYPE(arg)->tp_name,
      ")");
  c10::AutogradState::get_tls_state().set_fw_grad_mode(arg == Py_True);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* is_fwd_grad_enabled(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  if (c10::AutogradState::get_tls_state().get_fw_grad_mode()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

static PyObject* set_multithreading_enabled(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
      "set_multithreading_enabled(bool enabled)",
  });
  ParsedArgs<1> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  if (at::impl::torch_function_mode_enabled()) {
    auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
    return handle_torch_function(
        r,
        args,
        kwargs,
        torch_C_module,
        "torch._C",
        "_set_multithreading_enabled");
  }
  auto multithreading_enabled = r.toBool(0);
  c10::AutogradState::get_tls_state().set_multithreading_enabled(
      multithreading_enabled);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* is_multithreading_enabled(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  if (c10::AutogradState::get_tls_state().get_multithreading_enabled()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

static PyObject* set_view_replay_enabled(
    PyObject* self,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
      "set_view_replay_enabled(bool enabled)",
  });
  ParsedArgs<1> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  if (at::impl::torch_function_mode_enabled()) {
    auto torch_C_module = THPObjectPtr(PyImport_ImportModule("torch._C"));
    return handle_torch_function(
        r,
        args,
        kwargs,
        torch_C_module,
        "torch._C",
        "_set_view_replay_enabled");
  }
  auto view_replay_enabled = r.toBool(0);
  c10::AutogradState::get_tls_state().set_view_replay_enabled(
      view_replay_enabled);
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* is_view_replay_enabled(PyObject* self, PyObject* args) {
  HANDLE_TH_ERRORS
  if (c10::AutogradState::get_tls_state().get_view_replay_enabled()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

static PyObject* is_inference_mode_enabled(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  if (c10::InferenceMode::is_enabled()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

static PyObject* set_anomaly_mode_enabled(
    PyObject* _unused,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({
      "set_anomaly_enabled(bool enabled, bool check_nan=True)",
  });
  ParsedArgs<2> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);
  AnomalyMode::set_enabled(r.toBool(0), r.toBool(1));
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* is_anomaly_mode_enabled(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  if (AnomalyMode::is_enabled()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

static PyObject* is_anomaly_check_nan_enabled(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  if (AnomalyMode::should_check_nan()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

static PyObject* python_enter_dual_level(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  // It is unlikely that the depth of forward nesting will overflow int64_t so
  // we just static cast here.
  return utils::wrap(static_cast<int64_t>(forward_ad::enter_dual_level()));
  END_HANDLE_TH_ERRORS
}

static PyObject* python_exit_dual_level(
    PyObject* _unused,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({"exit_dual_level(int64_t level)"});

  ParsedArgs<1> parsed_args;
  auto _r = parser.parse(args, kwargs, parsed_args);

  auto idx = _r.toInt64(0);
  // Make sure the given index is valid before casting it
  TORCH_CHECK(idx >= 0, "Dual level must be a non-negative number.");
  forward_ad::exit_dual_level(static_cast<uint64_t>(idx));
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* is_torch_function_mode_enabled(
    PyObject* _unused,
    PyObject* _unused2) {
  HANDLE_TH_ERRORS
  if (at::impl::torch_function_mode_enabled()) {
    Py_RETURN_TRUE;
  } else {
    Py_RETURN_FALSE;
  }
  END_HANDLE_TH_ERRORS
}

static PyObject* push_on_torch_function_stack(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  if (arg != Py_None) {
    Py_INCREF(arg);
    at::impl::PythonTorchFunctionTLS::push_onto_stack(
        std::make_shared<c10::SafePyObject>(arg, getPyInterpreter()));
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* pop_torch_function_stack(
    PyObject* _unused,
    PyObject* _unused2) {
  HANDLE_TH_ERRORS
  const auto& mode = at::impl::PythonTorchFunctionTLS::pop_stack();
  auto* r = mode->ptr(getPyInterpreter());
  Py_INCREF(r);
  return r;
  END_HANDLE_TH_ERRORS
}

static PyObject* get_function_stack_at(
    PyObject* _unused,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({"get_stack_at(int64_t level)"});

  ParsedArgs<1> parsed_args;
  auto _r = parser.parse(args, kwargs, parsed_args);

  auto idx = _r.toInt64(0);
  const auto& mode = at::impl::PythonTorchFunctionTLS::get_stack_at(idx);
  auto* r = mode->ptr(getPyInterpreter());
  Py_INCREF(r);
  return r;
  END_HANDLE_TH_ERRORS
}

static PyObject* len_torch_function_stack(
    PyObject* _unused,
    PyObject* _unused2) {
  HANDLE_TH_ERRORS
  const auto len = at::impl::PythonTorchFunctionTLS::stack_len();
  return utils::wrap(static_cast<int64_t>(len));
  END_HANDLE_TH_ERRORS
}

static PyObject* push_on_torch_dispatch_stack(
    PyObject* _unused,
    PyObject* arg) {
  HANDLE_TH_ERRORS
  if (arg != Py_None) {
    using c10::impl::TorchDispatchModeKey;
    // When we push a mode onto the mode stack, we need to
    // check if it's an "infra" mode, by checking its _mode_key attribute.
    std::optional<c10::impl::TorchDispatchModeKey> mode_key = std::nullopt;
    py::object maybe_mode_key_obj =
        PyObject_FastGetAttrString(arg, "_mode_key");
    if (maybe_mode_key_obj) {
      mode_key = py::cast<c10::impl::TorchDispatchModeKey>(maybe_mode_key_obj);
      c10::impl::TorchDispatchModeTLS::set_mode(
          std::make_shared<c10::impl::PyObject_TorchDispatchMode>(
              arg, getPyInterpreter()),
          mode_key.value());
    } else {
      c10::impl::TorchDispatchModeTLS::push_non_infra_mode_onto_stack(
          std::make_shared<c10::impl::PyObject_TorchDispatchMode>(
              arg, getPyInterpreter()));
    }
    Py_INCREF(arg);
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
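
// A hedged sketch of the push path (the exact Python-side plumbing may vary
// across versions): TorchDispatchMode.__enter__ eventually routes through the
// binding above. "Infra" modes (those carrying a _mode_key, e.g.
// FakeTensorMode) land in a dedicated per-key slot rather than on the
// user-visible portion of the mode stack:
//   with FakeTensorMode():  # pushed via _push_on_torch_dispatch_stack
//       ...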

static PyObject* pop_torch_dispatch_stack(
    PyObject* _unused,
    PyObject* maybe_mode_key) {
  HANDLE_TH_ERRORS
  std::optional<c10::impl::TorchDispatchModeKey> mode_key = std::nullopt;
  PyObject* r = nullptr;
  if (maybe_mode_key != Py_None) {
    mode_key = py::cast<c10::impl::TorchDispatchModeKey>(maybe_mode_key);
    auto maybe_mode =
        c10::impl::TorchDispatchModeTLS::unset_mode(mode_key.value());
    TORCH_CHECK(
        maybe_mode.has_value(),
        "Attempted to unset ",
        c10::impl::to_string(mode_key.value()),
        ", but there wasn't one active.");
    auto mode = maybe_mode.value();
    r = mode->ptr(getPyInterpreter());
  } else {
    auto mode = c10::impl::TorchDispatchModeTLS::pop_stack();
    r = mode->ptr(getPyInterpreter());
  }
  Py_INCREF(r);
  return r;
  END_HANDLE_TH_ERRORS
}

static PyObject* get_dispatch_stack_at(
    PyObject* _unused,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  static PythonArgParser parser({"get_stack_at(int64_t level)"});

  ParsedArgs<1> parsed_args;
  auto _r = parser.parse(args, kwargs, parsed_args);

  auto idx = _r.toInt64(0);
  const auto& mode = c10::impl::TorchDispatchModeTLS::get_stack_at(idx);
  auto* r = mode->ptr(getPyInterpreter());
  Py_INCREF(r);
  return r;
  END_HANDLE_TH_ERRORS
}

static PyObject* set_dispatch_mode(PyObject* _unused, PyObject* mode) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(mode != Py_None);

  py::object maybe_mode_key_obj = PyObject_FastGetAttrString(mode, "_mode_key");
  TORCH_CHECK(
      maybe_mode_key_obj,
      "set_dispatch_mode() called with a mode that does not contain a _mode_key attribute!");
  auto mode_key = py::cast<c10::impl::TorchDispatchModeKey>(maybe_mode_key_obj);

  Py_INCREF(mode);
  c10::impl::TorchDispatchModeTLS::set_mode(
      std::make_shared<c10::impl::PyObject_TorchDispatchMode>(
          mode, getPyInterpreter()),
      mode_key);

  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

static PyObject* get_dispatch_mode(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(arg != Py_None);
  auto mode_key = py::cast<c10::impl::TorchDispatchModeKey>(arg);

  auto maybe_mode = c10::impl::TorchDispatchModeTLS::get_mode(mode_key);
  if (maybe_mode == std::nullopt) {
    Py_RETURN_NONE;
  }
  // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
  auto* r = maybe_mode.value()->ptr(getPyInterpreter());
  Py_INCREF(r);
  return r;
  END_HANDLE_TH_ERRORS
}

static PyObject* unset_dispatch_mode(PyObject* _unused, PyObject* arg) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(arg != Py_None);
  auto mode_key = py::cast<c10::impl::TorchDispatchModeKey>(arg);

  const auto maybe_mode = c10::impl::TorchDispatchModeTLS::unset_mode(mode_key);
  if (maybe_mode == std::nullopt) {
    Py_RETURN_NONE;
  }
  // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
  auto* r = maybe_mode.value()->ptr(getPyInterpreter());
  Py_INCREF(r);
  return r;
  END_HANDLE_TH_ERRORS
}

static PyObject* len_torch_dispatch_stack(PyObject* _unused, PyObject* args) {
  HANDLE_TH_ERRORS
  const auto len = c10::impl::TorchDispatchModeTLS::stack_len();
  return utils::wrap(static_cast<int64_t>(len));
  END_HANDLE_TH_ERRORS
}

PyObject* THPModule_increment_version(
    PyObject* _unused,
    PyObject* tensor_list) {
  HANDLE_TH_ERRORS
  auto iterator = THPObjectPtr(PyObject_GetIter(tensor_list));
  TORCH_CHECK(iterator, "increment_version expects an Iterable[Tensor] as input");
  auto item = THPObjectPtr(PyIter_Next(iterator));
  while (item) {
    TORCH_CHECK(
        THPVariable_Check(item),
        "increment_version expects each element of the iterable to be a tensor");
    auto t = THPVariable_Unpack(item);
    if (!t.is_inference()) {
      torch::autograd::increment_version(t);
    }
    item = THPObjectPtr(PyIter_Next(iterator));
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

// autograd methods on torch._C
static PyMethodDef methods[] = {
    {"_set_grad_enabled",
     castPyCFunctionWithKeywords(set_grad_enabled),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"is_grad_enabled", is_grad_enabled, METH_NOARGS, nullptr},
    {"_set_fwd_grad_enabled", set_fwd_grad_enabled, METH_O, nullptr},
    {"_is_fwd_grad_enabled", is_fwd_grad_enabled, METH_NOARGS, nullptr},
    {"is_inference_mode_enabled",
     is_inference_mode_enabled,
     METH_NOARGS,
     nullptr},
    {"set_autocast_enabled",
     castPyCFunctionWithKeywords(set_autocast_enabled),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"is_autocast_enabled",
     castPyCFunctionWithKeywords(is_autocast_enabled),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"set_autocast_dtype",
     castPyCFunctionWithKeywords(set_autocast_dtype),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"get_autocast_dtype",
     castPyCFunctionWithKeywords(get_autocast_dtype),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_is_any_autocast_enabled", is_any_autocast_enabled, METH_NOARGS, nullptr},
    {"_is_autocast_available",
     castPyCFunctionWithKeywords(is_autocast_available),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"clear_autocast_cache", clear_autocast_cache, METH_NOARGS, nullptr},
    {"set_autocast_cpu_enabled", set_autocast_cpu_enabled, METH_O, nullptr},
    {"is_autocast_cpu_enabled", is_autocast_cpu_enabled, METH_NOARGS, nullptr},
    {"set_autocast_cpu_dtype", set_autocast_cpu_dtype, METH_O, nullptr},
    {"get_autocast_cpu_dtype", get_autocast_cpu_dtype, METH_NOARGS, nullptr},
    {"set_autocast_gpu_dtype", set_autocast_gpu_dtype, METH_O, nullptr},
    {"get_autocast_gpu_dtype", get_autocast_gpu_dtype, METH_NOARGS, nullptr},
    {"set_autocast_xla_enabled", set_autocast_xla_enabled, METH_O, nullptr},
    {"is_autocast_xla_enabled", is_autocast_xla_enabled, METH_NOARGS, nullptr},
    {"set_autocast_xla_dtype", set_autocast_xla_dtype, METH_O, nullptr},
    {"get_autocast_xla_dtype", get_autocast_xla_dtype, METH_NOARGS, nullptr},
    {"set_autocast_ipu_enabled", set_autocast_ipu_enabled, METH_O, nullptr},
    {"is_autocast_ipu_enabled", is_autocast_ipu_enabled, METH_NOARGS, nullptr},
    {"set_autocast_ipu_dtype", set_autocast_ipu_dtype, METH_O, nullptr},
    {"get_autocast_ipu_dtype", get_autocast_ipu_dtype, METH_NOARGS, nullptr},
    {"autocast_increment_nesting",
     autocast_increment_nesting,
     METH_NOARGS,
     nullptr},
    {"autocast_decrement_nesting",
     autocast_decrement_nesting,
     METH_NOARGS,
     nullptr},
    {"is_autocast_cache_enabled",
     is_autocast_cache_enabled,
     METH_NOARGS,
     nullptr},
    {"set_autocast_cache_enabled", set_autocast_cache_enabled, METH_O, nullptr},
    {"_increment_version", THPModule_increment_version, METH_O, nullptr},
    {"set_anomaly_enabled",
     castPyCFunctionWithKeywords(set_anomaly_mode_enabled),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"is_anomaly_enabled", is_anomaly_mode_enabled, METH_NOARGS, nullptr},
    {"is_anomaly_check_nan_enabled",
     is_anomaly_check_nan_enabled,
     METH_NOARGS,
     nullptr},
    {"_is_multithreading_enabled",
     is_multithreading_enabled,
     METH_NOARGS,
     nullptr},
    {"_set_multithreading_enabled",
     castPyCFunctionWithKeywords(set_multithreading_enabled),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_is_view_replay_enabled", is_view_replay_enabled, METH_NOARGS, nullptr},
    {"_set_view_replay_enabled",
     castPyCFunctionWithKeywords(set_view_replay_enabled),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_enter_dual_level", python_enter_dual_level, METH_NOARGS, nullptr},
    {"_exit_dual_level",
     castPyCFunctionWithKeywords(python_exit_dual_level),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_is_torch_function_mode_enabled",
     is_torch_function_mode_enabled,
     METH_NOARGS,
     nullptr},
    {"_push_on_torch_function_stack",
     push_on_torch_function_stack,
     METH_O,
     nullptr},
    {"_pop_torch_function_stack",
     pop_torch_function_stack,
     METH_NOARGS,
     nullptr},
    {"_get_function_stack_at",
     castPyCFunctionWithKeywords(get_function_stack_at),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_len_torch_function_stack",
     len_torch_function_stack,
     METH_NOARGS,
     nullptr},
    {"_push_on_torch_dispatch_stack",
     push_on_torch_dispatch_stack,
     METH_O,
     nullptr},
    {"_pop_torch_dispatch_stack", pop_torch_dispatch_stack, METH_O, nullptr},
    {"_get_dispatch_stack_at",
     castPyCFunctionWithKeywords(get_dispatch_stack_at),
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_len_torch_dispatch_stack",
     len_torch_dispatch_stack,
     METH_NOARGS,
     nullptr},
    {"_set_dispatch_mode", set_dispatch_mode, METH_O, nullptr},
    {"_get_dispatch_mode", get_dispatch_mode, METH_O, nullptr},
    {"_unset_dispatch_mode", unset_dispatch_mode, METH_O, nullptr},

    {nullptr, nullptr, 0, nullptr}};

PyMethodDef* python_functions() {
  return methods;
}

} // namespace torch::autograd