xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/init.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/ivalue.h>
2 #include <torch/csrc/utils/init.h>
3 #include <torch/csrc/utils/throughput_benchmark.h>
4 
5 #include <pybind11/functional.h>
6 #include <torch/csrc/utils/pybind.h>
7 
8 namespace torch::throughput_benchmark {
9 
initThroughputBenchmarkBindings(PyObject * module)10 void initThroughputBenchmarkBindings(PyObject* module) {
11   auto m = py::handle(module).cast<py::module>();
12   using namespace torch::throughput_benchmark;
13   py::class_<BenchmarkConfig>(m, "BenchmarkConfig")
14       .def(py::init<>())
15       .def_readwrite(
16           "num_calling_threads", &BenchmarkConfig::num_calling_threads)
17       .def_readwrite("num_worker_threads", &BenchmarkConfig::num_worker_threads)
18       .def_readwrite("num_warmup_iters", &BenchmarkConfig::num_warmup_iters)
19       .def_readwrite("num_iters", &BenchmarkConfig::num_iters)
20       .def_readwrite(
21           "profiler_output_path", &BenchmarkConfig::profiler_output_path);
22 
23   py::class_<BenchmarkExecutionStats>(m, "BenchmarkExecutionStats")
24       .def_readonly("latency_avg_ms", &BenchmarkExecutionStats::latency_avg_ms)
25       .def_readonly("num_iters", &BenchmarkExecutionStats::num_iters);
26 
27   py::class_<ThroughputBenchmark>(m, "ThroughputBenchmark", py::dynamic_attr())
28       .def(py::init<jit::Module>())
29       .def(py::init<py::object>())
30       .def(
31           "add_input",
32           [](ThroughputBenchmark& self, py::args args, py::kwargs kwargs) {
33             self.addInput(std::move(args), std::move(kwargs));
34           })
35       .def(
36           "run_once",
37           [](ThroughputBenchmark& self,
38              py::args args,
39              const py::kwargs& kwargs) {
40             // Depending on this being ScriptModule of nn.Module we will release
41             // the GIL or not further down in the stack
42             return self.runOnce(std::move(args), kwargs);
43           })
44       .def(
45           "benchmark",
46           [](ThroughputBenchmark& self, const BenchmarkConfig& config) {
47             // The benchmark always runs without the GIL. GIL will be used where
48             // needed. This will happen only in the nn.Module mode when
49             // manipulating inputs and running actual inference
50             pybind11::gil_scoped_release no_gil_guard;
51             return self.benchmark(config);
52           });
53 }
54 
55 } // namespace torch::throughput_benchmark
56