#include <torch/csrc/profiler/python/init.h>

#include <ATen/record_function.h>
#include <c10/core/impl/PyInterpreter.h>
#include <c10/util/overloaded.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/autograd/utils/wrap_outputs.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/profiler/collection.h>
#include <torch/csrc/profiler/python/combined_traceback.h>
#include <torch/csrc/profiler/standalone/execution_trace_observer.h>
#include <torch/csrc/utils/pybind.h>

struct THPCapturedTraceback {
  PyObject_HEAD std::shared_ptr<torch::CapturedTraceback> data;
};

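// GC support for THPCapturedTraceback: `traverse` and `clear` forward to the
// underlying torch::CapturedTraceback so the Python cycle collector can see,
// and break, any reference cycles through the Python frames it retains.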
static int THPCapturedTraceback_traverse(
    PyObject* self,
    visitproc visit,
    void* arg) {
  return ((THPCapturedTraceback*)self)
      ->data->traversePython((int (*)(void*, void*))visit, arg);
}

static int THPCapturedTraceback_clear(PyObject* self) {
  return ((THPCapturedTraceback*)self)->data->clearPython();
}

static void THPCapturedTraceback_dealloc(PyObject* self_) {
  auto* self = (THPCapturedTraceback*)self_;
  PyObject_GC_UnTrack(self);
  self->data.~shared_ptr<torch::CapturedTraceback>();
  // Promptly trigger delayed frees since we hold the GIL here.
  torch::freeDeadCapturedTracebackFrames();
  PyObject_GC_Del(self);
}

PyTypeObject THPCapturedTracebackType = {
    PyVarObject_HEAD_INIT(
        nullptr,
        0) "torch._C._profiler.CapturedTraceback", /* tp_name */
    sizeof(THPCapturedTraceback), /* tp_basicsize */
    0, /* tp_itemsize */
    THPCapturedTraceback_dealloc, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    // NOLINTNEXTLINE(misc-redundant-expression)
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC, /* tp_flags */
    nullptr, /* tp_doc */
    (traverseproc)THPCapturedTraceback_traverse, /* tp_traverse */
    (inquiry)THPCapturedTraceback_clear, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    nullptr, /* tp_methods */
    nullptr, /* tp_members */
    nullptr, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    nullptr, /* tp_new */
};

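// A type caster so bound functions can accept and return
// std::shared_ptr<torch::CapturedTraceback> directly; values cross the
// boundary wrapped in the hand-rolled THPCapturedTraceback type above.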
namespace pybind11::detail {

template <>
struct type_caster<std::shared_ptr<torch::CapturedTraceback>> {
 public:
  PYBIND11_TYPE_CASTER(
      std::shared_ptr<torch::CapturedTraceback>,
      _("torch._C._profiler.CapturedTraceback"));

  bool load(handle src, bool) {
    if (Py_TYPE(src.ptr()) == &THPCapturedTracebackType) {
      value = reinterpret_cast<THPCapturedTraceback*>(src.ptr())->data;
      return true;
    }
    return false;
  }

  static handle cast(
      std::shared_ptr<torch::CapturedTraceback> src,
      return_value_policy /* policy */,
      handle /* parent */) {
    auto* r = PyObject_GC_New(THPCapturedTraceback, &THPCapturedTracebackType);
    new (&r->data) std::shared_ptr<torch::CapturedTraceback>(std::move(src));
    return py::handle((PyObject*)r);
  }
};

} // namespace pybind11::detail

namespace torch::profiler {

/* [NOTE: RecordFunctionFast]
 * This is an alternate way to call record_function from Python.
 * The torch.profiler.record_function context manager is slow (~14us on
 * benchmarks in Aug 2023), which is usually fine for module-level annotations
 * in Python, but too slow for per-op annotations. Part of the reason it is
 * slow is that the calls go through the dispatcher, in order to make the
 * record_function calls work with TorchScript.
 *
 * This implementation doesn't go through the dispatcher, so it won't work
 * with any feature relying on the dispatcher (e.g. TorchScript or
 * torch.compile).
 *
 * An alternate solution would be to implement a Python context manager that
 * calls into C++ for the enter/exit functions:
 *    @contextlib.contextmanager
 *    def record_function_fast(name):
 *      rf = torch._C._record_function_fast_enter(name)
 *      try:
 *        yield
 *      finally:
 *        torch._C._record_function_fast_exit(rf)
 * The C++ implementation here is faster by ~0.2-0.4us per context manager.
 */

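/* A minimal usage sketch from Python, based on the binding registered at the
 * bottom of this file as torch._C._profiler._RecordFunctionFast. The object
 * is constructed once and re-entered, so the hot loop only pays for
 * __enter__/__exit__ (run_op is a hypothetical workload):
 *
 *    rf = torch._C._profiler._RecordFunctionFast("my_op")
 *    with torch.profiler.profile() as prof:
 *        for _ in range(100):
 *            with rf:
 *                run_op()
 */
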
namespace {
struct RecordFunctionFast {
  PyObject_HEAD PyObject* name;
  PyObject* input_values;
  PyObject* keyword_values;
  std::unique_ptr<at::RecordFunction> guard;
};

PyObject* RecordFunctionFast_new(
    PyTypeObject* subtype,
    PyObject* args,
    PyObject* kwargs) {
  RecordFunctionFast* self = (RecordFunctionFast*)subtype->tp_alloc(subtype, 0);
  if (self != nullptr) {
    self->name = nullptr;
    self->input_values = nullptr;
    self->keyword_values = nullptr;
    self->guard.reset();
  }
  return (PyObject*)self;
}

int RecordFunctionFast_init(
    PyObject* selfGeneric,
    PyObject* args,
    PyObject* kwargs) {
  auto self = (RecordFunctionFast*)selfGeneric;
  // NOLINTNEXTLINE(*-c-arrays*)
  constexpr const char* kwlist[] = {
      "name", "input_values", "keyword_values", nullptr};
  PyObject* name = nullptr;
  PyObject* input_values = nullptr;
  PyObject* keyword_values = nullptr;
  if (!PyArg_ParseTupleAndKeywords(
          args,
          kwargs,
          "O|OO", // name is a required PyObject; input_values and
                  // keyword_values are optional PyObjects
          // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
          const_cast<char**>(kwlist),
          &name,
          &input_values,
          &keyword_values)) {
    return -1;
  }
  if (name) {
    TORCH_CHECK(
        THPUtils_checkString(name),
        "The name passed to RecordFunctionFast must be a string");
    Py_INCREF(name);
    self->name = name;
  }
  if (input_values) {
    TORCH_CHECK(
        PyList_Check(input_values) || PyTuple_Check(input_values),
        "input_values must be a list or tuple");
    Py_INCREF(input_values);
    self->input_values = input_values;
  }
  if (keyword_values) {
    TORCH_CHECK(PyDict_Check(keyword_values), "keyword_values must be a dict");
    Py_INCREF(keyword_values);
    self->keyword_values = keyword_values;
  }
  return 0;
}

void RecordFunctionFast_dealloc(PyObject* selfGeneric) {
  auto self = (RecordFunctionFast*)selfGeneric;
  Py_CLEAR(self->name);
  Py_CLEAR(self->input_values);
  Py_CLEAR(self->keyword_values);
  if (self->guard) {
    self->guard.reset();
  }
  Py_TYPE(self)->tp_free(self);
}

PyObject* RecordFunctionFast_enter(PyObject* selfGeneric, PyObject* unused) {
  HANDLE_TH_ERRORS
  if (torch::profiler::impl::ProfilerStateBase::get() != nullptr) {
    auto self = (RecordFunctionFast*)selfGeneric;
    TORCH_INTERNAL_ASSERT(
        !self->guard,
        "Trying to enter a new record_function_fast context but the guard is unexpectedly already set");
    self->guard =
        std::make_unique<at::RecordFunction>(at::RecordScope::FUNCTION);
    std::vector<at::IValue> args;
    std::unordered_map<std::string, at::IValue> kwargs;
    bool profiler_need_input = torch::autograd::profiler::profilerEnabled() &&
        torch::autograd::profiler::getProfilerConfig().report_input_shapes;
    // Parse the positional inputs, if provided.
    if (self->input_values != nullptr && profiler_need_input) {
      THPObjectPtr input_fast(
          PySequence_Fast(self->input_values, "input must be a sequence"));
      PyObject** input_items = PySequence_Fast_ITEMS(input_fast.get());
      for (Py_ssize_t i = 0; i < PySequence_Fast_GET_SIZE(input_fast.get());
           i++) {
        PyObject* item = input_items[i];
        auto match = torch::jit::tryToInferType(item);
        if (match.success()) {
          args.push_back(torch::jit::toIValue(item, match.type()));
        }
      }
    }

    // Parse the keyword inputs, if provided.
    if (self->keyword_values != nullptr && profiler_need_input) {
      Py_ssize_t pos = 0;
      PyObject *key = nullptr, *value = nullptr;
      while (PyDict_Next(self->keyword_values, &pos, &key, &value)) {
        // Get the string representation of the key and value
        std::string key_str = THPUtils_unpackString(key);
        at::IValue ivalue;
        if (THPUtils_checkString(value)) {
          ivalue = at::IValue(THPUtils_unpackString(value));
        } else {
          auto match = torch::jit::tryToInferPrimitiveType(value);
          if (match.success()) {
            ivalue = torch::jit::toIValue(value, match.type());
          } else {
            TORCH_WARN("Unable to infer type of value for keyword: ", key_str);
            ivalue = at::IValue("NULL");
          }
        }
        kwargs[key_str] = ivalue;
      }
    }
    self->guard->before(THPUtils_unpackString(self->name), &args, &kwargs);
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}

PyObject* RecordFunctionFast_exit(PyObject* selfGeneric, PyObject* unused) {
  HANDLE_TH_ERRORS
  if (torch::profiler::impl::ProfilerStateBase::get() != nullptr) {
    auto self = (RecordFunctionFast*)selfGeneric;
    TORCH_INTERNAL_ASSERT(
        self->guard,
        "Trying to exit an active record_function_fast context but no guard is set");
    self->guard.reset();
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
} // namespace

void initPythonBindings(PyObject* module) {
  auto rootModule = py::handle(module).cast<py::module>();
  auto m = rootModule.def_submodule("_profiler");

  using namespace torch::profiler::impl;

  py::enum_<at::RecordScope>(m, "RecordScope")
      .value("FUNCTION", at::RecordScope::FUNCTION)
      .value("BACKWARD_FUNCTION", at::RecordScope::BACKWARD_FUNCTION)
      .value("TORCHSCRIPT_FUNCTION", at::RecordScope::TORCHSCRIPT_FUNCTION)
      .value("KERNEL_FUNCTION_DTYPE", at::RecordScope::KERNEL_FUNCTION_DTYPE)
      .value("CUSTOM_CLASS", at::RecordScope::CUSTOM_CLASS)
      .value("BUILD_FEATURE", at::RecordScope::BUILD_FEATURE)
      .value("LITE_INTERPRETER", at::RecordScope::LITE_INTERPRETER)
      .value("USER_SCOPE", at::RecordScope::USER_SCOPE)
      .value("STATIC_RUNTIME_OP", at::RecordScope::STATIC_RUNTIME_OP)
      .value("STATIC_RUNTIME_MODEL", at::RecordScope::STATIC_RUNTIME_MODEL);

  py::enum_<ProfilerState>(m, "ProfilerState")
      .value("Disabled", ProfilerState::Disabled)
      .value("CPU", ProfilerState::CPU)
      .value("CUDA", ProfilerState::CUDA)
      .value("NVTX", ProfilerState::NVTX)
      .value("ITT", ProfilerState::ITT)
      .value("PRIVATEUSE1", ProfilerState::PRIVATEUSE1)
      .value("KINETO", ProfilerState::KINETO)
      .value("KINETO_GPU_FALLBACK", ProfilerState::KINETO_GPU_FALLBACK)
      .value(
          "KINETO_PRIVATEUSE1_FALLBACK",
          ProfilerState::KINETO_PRIVATEUSE1_FALLBACK);

  py::enum_<ActiveProfilerType>(m, "ActiveProfilerType")
      .value("NONE", ActiveProfilerType::NONE)
      .value("LEGACY", ActiveProfilerType::LEGACY)
      .value("KINETO", ActiveProfilerType::KINETO)
      .value("NVTX", ActiveProfilerType::NVTX)
      .value("ITT", ActiveProfilerType::ITT)
      .value("PRIVATEUSE1", ActiveProfilerType::PRIVATEUSE1);

  py::enum_<ActivityType>(m, "ProfilerActivity")
      .value("CPU", ActivityType::CPU)
      .value("XPU", ActivityType::XPU)
      .value("MTIA", ActivityType::MTIA)
      .value("CUDA", ActivityType::CUDA)
      .value("PrivateUse1", ActivityType::PrivateUse1);

  py::class_<ExperimentalConfig>(m, "_ExperimentalConfig")
      .def(
          py::init<
              std::vector<std::string> /* profiler_metrics */,
              bool /* profiler_measure_per_kernel */,
              bool /* verbose */,
              std::vector<std::string> /* performance_events */,
              bool /* enable_cuda_sync_events */
              >(),
          "An experimental config for Kineto features. Please note that "
          "backward compatibility is not guaranteed.\n"
          " profiler_metrics : a list of CUPTI profiler metrics used\n"
          "    to measure GPU performance events.\n"
          "    If this list contains values, Kineto runs in CUPTI profiler mode\n"
          " profiler_measure_per_kernel (bool) : whether to profile metrics per kernel\n"
          "    or for the entire measurement duration.\n"
          " verbose (bool) : whether the trace file has `Call stack` field or not.\n"
          " performance_events : a list of profiler events to be used for measurement.\n"
          " enable_cuda_sync_events : for CUDA profiling mode, enable adding CUDA synchronization events\n"
          "    that expose CUDA device, stream and event synchronization activities. This feature is new\n"
          "    and currently disabled by default.\n",
          py::arg("profiler_metrics") = std::vector<std::string>(),
          py::arg("profiler_measure_per_kernel") = false,
          py::arg("verbose") = false,
          py::arg("performance_events") = std::vector<std::string>(),
          py::arg("enable_cuda_sync_events") = false)
      .def(py::pickle(
          [](const ExperimentalConfig& p) { // __getstate__
            py::list py_metrics;
            for (const auto& metric : p.profiler_metrics) {
              py::bytes mbytes(metric);
              py_metrics.append(mbytes);
            }
            py::list py_perf_events;
            for (const auto& event : p.performance_events) {
              py::bytes mbytes(event);
              py_perf_events.append(mbytes);
            }
            /* Return a tuple that fully encodes the state of the config */
            return py::make_tuple(
                py_metrics,
                p.profiler_measure_per_kernel,
                p.verbose,
                p.enable_cuda_sync_events,
                py_perf_events);
          },
          [](const py::tuple& t) { // __setstate__
            if (t.size() < 4) {
              throw std::runtime_error("Expected at least 4 values in state");
            }

            py::list py_metrics = t[0].cast<py::list>();
            std::vector<std::string> metrics;
            metrics.reserve(py_metrics.size());

            for (const auto& py_metric : py_metrics) {
              metrics.push_back(py::str(py_metric));
            }

            std::vector<std::string> performance_events;
            if (t.size() == 5) {
              py::list py_perf_events = t[4].cast<py::list>();
              performance_events.reserve(py_perf_events.size());
              for (const auto& py_perf_event : py_perf_events) {
                performance_events.push_back(py::str(py_perf_event));
              }
            }

            return ExperimentalConfig(
                std::move(metrics),
                t[1].cast<bool>(),
                t[2].cast<bool>(),
                std::move(performance_events),
                t[3].cast<bool>());
          }));

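  /* A round-trip sketch from Python, relying on the py::pickle hooks above;
   * the metric name is illustrative, not a guaranteed CUPTI metric:
   *
   *    import pickle, torch
   *    cfg = torch._C._profiler._ExperimentalConfig(
   *        profiler_metrics=["kineto__tensor_core_insts"], verbose=True)
   *    cfg2 = pickle.loads(pickle.dumps(cfg))
   */
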
  py::class_<ProfilerConfig>(m, "ProfilerConfig")
      .def(py::init<
           ProfilerState,
           bool, /* report_input_shapes */
           bool, /* profile_memory */
           bool, /* with_stack */
           bool, /* with_flops */
           bool, /* with_modules */
           ExperimentalConfig /* experimental_config */
           >());

  py::enum_<EventType>(m, "_EventType")
      .value("TorchOp", EventType::TorchOp)
      .value("Backend", EventType::Backend)
      .value("Vulkan", EventType::Vulkan)
      .value("Allocation", EventType::Allocation)
      .value("PyCall", EventType::PyCall)
      .value("PyCCall", EventType::PyCCall)
      .value("Kineto", EventType::Kineto);

  py::class_<TensorMetadata>(m, "_TensorMetadata")
      .def_property_readonly("impl_ptr", &TensorMetadata::impl)
      .def_readonly("storage_data_ptr", &TensorMetadata::data_)
      .def_readonly("id", &TensorMetadata::id_)
      .def_readonly("allocation_id", &TensorMetadata::allocation_id_)
      .def_property_readonly(
          "layout",
          [](const TensorMetadata& metadata) {
            PyObject* layout_obj =
                torch::autograd::utils::wrap(metadata.layout_);
            return py::reinterpret_borrow<py::object>(layout_obj);
          })
      .def_readonly("device", &TensorMetadata::device_)
      .def_property_readonly(
          "dtype",
          [](const TensorMetadata& metadata) {
            return py::reinterpret_borrow<py::object>(
                torch::autograd::utils::wrap(metadata.dtype_));
          })
      .def_readonly("dim", &TensorMetadata::size_dim_)
      .def_readonly("sizes", &TensorMetadata::sizes_)
      .def_readonly("strides", &TensorMetadata::strides_);

  using torch_op_t = ExtraFields<EventType::TorchOp>;
  py::class_<torch_op_t>(m, "_ExtraFields_TorchOp")
      .def_readonly("name", &torch_op_t::name_)
      .def_property_readonly(
          "inputs",
          [](const torch_op_t& op) {
            py::list out;
            for (const auto& input : op.inputs_) {
              std::visit(
                  c10::overloaded(
                      [&](const c10::IValue& v) {
                        out.append(torch::jit::toPyObject(v));
                      },
                      [&](const std::nullopt_t&) { out.append(py::none()); },
                      [&](const auto& v) { out.append(py::cast(v)); }),
                  input);
            }
            return out;
          })
      .def_readonly("scope", &torch_op_t::scope_)
      .def_readonly("sequence_number", &torch_op_t::sequence_number_)
      .def_readonly("allow_tf32_cublas", &torch_op_t::allow_tf32_cublas_);

  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<ExtraFields<EventType::Backend>>(m, "_ExtraFields_Backend");
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<ExtraFields<EventType::Vulkan>>(m, "_ExtraFields_Vulkan");

  using allocation_t = ExtraFields<EventType::Allocation>;
  py::class_<allocation_t>(m, "_ExtraFields_Allocation")
      .def_property_readonly(
          "ptr",
          [](const allocation_t& a) {
            return reinterpret_cast<intptr_t>(a.ptr_);
          })
      .def_readonly("id", &allocation_t::id_)
      .def_readonly("allocation_id", &allocation_t::allocation_id_)
      .def_readonly("alloc_size", &allocation_t::alloc_size_)
      .def_readonly("total_allocated", &allocation_t::total_allocated_)
      .def_readonly("total_reserved", &allocation_t::total_reserved_)
      .def_property_readonly("device", &allocation_t::device);

  py::class_<PyFrameState>(m, "_PyFrameState")
      .def_readonly("line_number", &PyFrameState::line_no_)
      .def_property_readonly(
          "file_name", [](const PyFrameState& s) { return s.filename_.str(); })
      .def_property_readonly("function_name", [](const PyFrameState& s) {
        return s.funcname_.str();
      });

  py::class_<NNModuleInfo>(m, "_NNModuleInfo")
      .def_property_readonly(
          "parameters",
          [](const NNModuleInfo& s) {
            py::list out;
            for (const auto& p : s.parameters_) {
              out.append(
                  py::make_tuple(p.name_, p.metadata_, p.grad_metadata_));
            }
            return out;
          })
      .def_property_readonly(
          "cls_name", [](const NNModuleInfo& s) { return s.cls_name_.str(); })
      .def_readonly("self_ptr", &NNModuleInfo::self_)
      .def_readonly("cls_ptr", &NNModuleInfo::cls_);

  py::class_<OptimizerInfo>(m, "_OptimizerInfo")
      .def_readonly("self_ptr", &OptimizerInfo::self_)
      .def_property_readonly("parameters", [](const OptimizerInfo& s) {
        py::list out;
        for (const auto& p : s.parameters_) {
          out.append(py::make_tuple(p.metadata_, p.grad_metadata_, p.state_));
        }
        return out;
      });

  py::class_<ExtraFields<EventType::PyCall>>(m, "_ExtraFields_PyCall")
      .def_readonly("callsite", &ExtraFields<EventType::PyCall>::callsite_)
      .def_readonly("caller", &ExtraFields<EventType::PyCall>::caller_)
      .def_readonly("module", &ExtraFields<EventType::PyCall>::module_)
      .def_readonly("optimizer", &ExtraFields<EventType::PyCall>::optimizer_);

  py::class_<ExtraFields<EventType::PyCCall>>(m, "_ExtraFields_PyCCall")
      .def_readonly("caller", &ExtraFields<EventType::PyCCall>::caller_);

  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<ExtraFields<EventType::OutOfMemory>>(
      m, "_ExtraFields_OutOfMemory");

  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<ExtraFields<EventType::Kineto>>(m, "_ExtraFields_Kineto");

  py::class_<Result, std::shared_ptr<Result>>(m, "_ProfilerEvent")
      .def_property_readonly("name", &Result::name)
      .def_property_readonly("tag", &Result::tag)
      .def_readonly("extra_fields", &Result::extra_fields_)
      .def_property_readonly(
          "typed",
          [](const Result& r) {
            return py::make_tuple(
                r.tag(),
                py::cast(r.extra_fields_, py::return_value_policy::reference));
          })
      .def_property_readonly(
          "id",
          [](const Result& r) {
            return reinterpret_cast<intptr_t>(r.shared_from_this().get());
          })
      .def_property_readonly(
          "parent", [](const Result& r) { return r.parent_.lock(); })
      .def_readonly("children", &Result::children_)
      .def_readonly("start_time_ns", &Result::start_time_ns_)
      .def_readonly("start_tid", &Result::start_tid_)
      .def_property_readonly("correlation_id", &Result::correlationID)
      .def_property_readonly("end_time_ns", &Result::endTimeNS)
      .def_property_readonly("duration_time_ns", [](const Result& r) {
        return r.endTimeNS() - r.start_time_ns_;
      });

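  /* A sketch of consuming these events from Python, assuming an event tree
   * obtained through the profiler's experimental API (an internal surface
   * that may change between releases):
   *
   *    with torch.profiler.profile(with_stack=True) as prof:
   *        run_model()  # hypothetical workload
   *    roots = prof.profiler.kineto_results.experimental_event_tree()
   *
   *    def visit(event, depth=0):
   *        print("  " * depth, event.name, event.duration_time_ns)
   *        for child in event.children:
   *            visit(child, depth + 1)
   *
   *    for root in roots:
   *        visit(root)
   */
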
  // PyTorch profiler execution trace internal interface.
  m.def(
      "_add_execution_trace_observer",
      &torch::profiler::impl::addExecutionTraceObserver,
      py::arg("output_file_name"));
  m.def(
      "_remove_execution_trace_observer",
      &torch::profiler::impl::removeExecutionTraceObserver);
  m.def(
      "_enable_execution_trace_observer",
      &torch::profiler::impl::enableExecutionTraceObserver);
  m.def(
      "_disable_execution_trace_observer",
      &torch::profiler::impl::disableExecutionTraceObserver);
  m.def(
      "_set_record_concrete_inputs_enabled_val",
      &torch::profiler::impl::set_record_concrete_inputs_enabled_val);
  m.def(
      "_set_fwd_bwd_enabled_val",
      &torch::profiler::impl::set_fwd_bwd_enabled_val);
  m.def(
      "_set_cuda_sync_enabled_val",
      &torch::profiler::impl::set_cuda_sync_enabled_val);

  TORCH_CHECK(PyType_Ready(&THPCapturedTracebackType) >= 0);
  PyModule_AddObject(
      m.ptr(), "CapturedTraceback", (PyObject*)&THPCapturedTracebackType);
  m.def(
      "gather_traceback",
      CapturedTraceback::gather,
      py::arg("python") = true,
      py::arg("script") = true,
      py::arg("cpp") = true);
  m.def("symbolize_tracebacks", [](const py::list& tbs) {
    std::vector<CapturedTraceback*> tb_ptrs;
    tb_ptrs.reserve(tbs.size());
    for (py::handle tb : tbs) {
      tb_ptrs.emplace_back(((THPCapturedTraceback*)tb.ptr())->data.get());
    }
    return py_symbolize(tb_ptrs);
  });
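  /* A minimal sketch of the traceback API from Python, using the names
   * registered just above under torch._C._profiler:
   *
   *    tb = torch._C._profiler.gather_traceback(python=True, script=True, cpp=True)
   *    (frames,) = torch._C._profiler.symbolize_tracebacks([tb])
   *    # `frames` is expected to describe the stack captured in `tb`
   */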
  // Directly convert address pointers to frames; used for testing symbolize.
  m.def(
      "symbolize_addresses",
      [](const std::vector<uint64_t>& frames, const std::string& mode_s) {
        std::vector<std::tuple<std::string, int64_t, std::string>> frames_out;
        torch::unwind::Mode mode = torch::unwind::Mode::addr2line;
        if (mode_s == "fast") {
          mode = torch::unwind::Mode::fast;
        } else if (mode_s == "addr2line") {
          mode = torch::unwind::Mode::addr2line;
        } else if (mode_s == "dladdr") {
          mode = torch::unwind::Mode::dladdr;
        } else {
          TORCH_CHECK(false, "unexpected mode ", mode_s);
        }
        std::vector<void*> frames_p;
        frames_p.reserve(frames.size());
        for (auto f : frames) {
          frames_p.push_back((void*)f); // NOLINT
        }
        auto frame_objects = unwind::symbolize(frames_p, mode);
        frames_out.reserve(frame_objects.size());
        for (auto& frame : frame_objects) {
          frames_out.emplace_back(frame.filename, frame.lineno, frame.funcname);
        }
        return frames_out;
      });
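  // Example (a sketch; the address is illustrative, not a real symbol):
  //   torch._C._profiler.symbolize_addresses([0x401000], "fast")
  // returns a list of (filename, line, funcname) tuples, one per address.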
  installCapturedTracebackPython();

  // NOLINTNEXTLINE(*-c-arrays*)
  static PyMethodDef RecordFunctionFast_methods[] = {
      {"__enter__", RecordFunctionFast_enter, METH_NOARGS, nullptr},
      {"__exit__", RecordFunctionFast_exit, METH_VARARGS, nullptr},
      {nullptr},
  };

  static PyTypeObject RecordFunctionFast_Type = {
      PyVarObject_HEAD_INIT(nullptr, 0)};

  RecordFunctionFast_Type.tp_name = "torch._C._profiler.RecordFunctionFast";
  RecordFunctionFast_Type.tp_basicsize = sizeof(RecordFunctionFast);
  RecordFunctionFast_Type.tp_dealloc = (destructor)RecordFunctionFast_dealloc;
  RecordFunctionFast_Type.tp_flags = Py_TPFLAGS_DEFAULT;
  RecordFunctionFast_Type.tp_methods = RecordFunctionFast_methods;
  RecordFunctionFast_Type.tp_init = RecordFunctionFast_init;
  RecordFunctionFast_Type.tp_new = RecordFunctionFast_new;

  if (PyType_Ready(&RecordFunctionFast_Type) < 0) {
    throw python_error();
  }

  Py_INCREF(&RecordFunctionFast_Type);
  if (PyModule_AddObject(
          m.ptr(),
          "_RecordFunctionFast",
          (PyObject*)&RecordFunctionFast_Type) != 0) {
    Py_DECREF(&RecordFunctionFast_Type);
    throw python_error();
  }
}
} // namespace torch::profiler
669