// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/runtime/static/init.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/runtime/static/init.h>
2 
3 #include <torch/csrc/jit/passes/freeze_module.h>
4 #include <torch/csrc/jit/runtime/static/fusion.h>
5 #include <torch/csrc/jit/runtime/static/impl.h>
6 
7 #include <utility>
8 
// Default minimum fused-subgraph size for _fuse_to_static_module.
// This number is a heuristic determined with pytorch/benchmark.
static constexpr int DEFAULT_FUSION_SIZE = 4;
11 
12 namespace torch::jit {
13 
initStaticModuleBindings(PyObject * module)14 void initStaticModuleBindings(PyObject* module) {
15   auto m = py::handle(module).cast<py::module>();
16   py::class_<StaticModule> static_module(m, "StaticModule");
17   py::class_<StaticRuntime::IndividualMetrics>(
18       static_module, "IndividualMetrics")
19       .def_readonly("setup_time", &StaticRuntime::IndividualMetrics::setup_time)
20       .def_readonly(
21           "memory_alloc_time",
22           &StaticRuntime::IndividualMetrics::memory_alloc_time)
23       .def_readonly(
24           "memory_dealloc_time",
25           &StaticRuntime::IndividualMetrics::memory_dealloc_time)
26       .def_readonly(
27           "output_dealloc_time",
28           &StaticRuntime::IndividualMetrics::output_dealloc_time)
29       .def_readonly(
30           "first_iter_time", &StaticRuntime::IndividualMetrics::first_iter_time)
31       .def_readonly("total_time", &StaticRuntime::IndividualMetrics::total_time)
32       .def_readonly(
33           "out_nodes_count", &StaticRuntime::IndividualMetrics::out_nodes_count)
34       .def_readonly(
35           "total_nodes_count",
36           &StaticRuntime::IndividualMetrics::total_nodes_count)
37       .def_readonly(
38           "time_per_node", &StaticRuntime::IndividualMetrics::time_per_node)
39       .def_readonly(
40           "time_per_node_type",
41           &StaticRuntime::IndividualMetrics::time_per_node_type)
42       .def_readonly(
43           "percent_per_node_type",
44           &StaticRuntime::IndividualMetrics::percent_per_node_type)
45       .def_readonly(
46           "instances_per_node_type",
47           &StaticRuntime::IndividualMetrics::instances_per_node_type)
48       .def_readonly("out_nodes", &StaticRuntime::IndividualMetrics::out_nodes);
49   static_module
50       .def(
51           "__call__",
52           [](StaticModule& self,
53              const py::args& args,
54              const py::kwargs& kwargs) {
55             std::vector<c10::IValue> arg_ivalues;
56             arg_ivalues.reserve(args.size());
57             std::unordered_map<std::string, c10::IValue> kwarg_ivalues;
58             kwarg_ivalues.reserve(kwargs.size());
59             for (const auto& arg : args) {
60               auto ivalue = torch::jit::toIValue(arg, c10::AnyType::get());
61               arg_ivalues.push_back(std::move(ivalue));
62             }
63             for (const auto& kv : kwargs) {
64               kwarg_ivalues[py::cast<std::string>(kv.first)] =
65                   torch::jit::toIValue(kv.second, c10::AnyType::get());
66             }
67             c10::IValue ret = self(arg_ivalues, kwarg_ivalues);
68             return toPyObject(std::move(ret));
69           })
70       .def(
71           "benchmark",
72           [](StaticModule& self,
73              const std::vector<at::Tensor>& args,
74              const std::unordered_map<std::string, at::Tensor>& kwargs,
75              const int warmup_runs,
76              const int main_runs) {
77             std::vector<c10::IValue> arg_ivalues{args.begin(), args.end()};
78             std::unordered_map<std::string, c10::IValue> kwarg_ivalues{
79                 kwargs.begin(), kwargs.end()};
80             self.runtime().benchmark(
81                 {arg_ivalues}, {kwarg_ivalues}, warmup_runs, main_runs);
82           })
83       .def(
84           "benchmark_individual_ops",
85           [](StaticModule& self,
86              const std::vector<at::Tensor>& args,
87              const std::unordered_map<std::string, at::Tensor>& kwargs,
88              const int warmup_runs,
89              const int main_runs) {
90             std::vector<c10::IValue> arg_ivalues{args.begin(), args.end()};
91             std::unordered_map<std::string, c10::IValue> kwarg_ivalues{
92                 kwargs.begin(), kwargs.end()};
93             return self.runtime().benchmark_individual_ops(
94                 {arg_ivalues}, {kwarg_ivalues}, warmup_runs, main_runs);
95           })
96       .def(
97           "runAsync",
98           [](StaticModule& self,
99              const py::tuple& args,
100              const py::dict& kwargs) {
101             std::vector<c10::IValue> arg_ivalues;
102             arg_ivalues.reserve(args.size());
103             for (const auto& elem : args) {
104               arg_ivalues.push_back(
105                   torch::jit::toIValue(elem, c10::AnyType::get()));
106             }
107             std::unordered_map<std::string, c10::IValue> kwarg_ivalues;
108             kwarg_ivalues.reserve(kwargs.size());
109             for (const auto& kv : kwargs) {
110               kwarg_ivalues[py::cast<std::string>(kv.first)] =
111                   torch::jit::toIValue(kv.second, c10::AnyType::get());
112             }
113             // custom executor for async op execution
114             auto task_launcher = [](const std::function<void()>& f) {
115               at::launch(f);
116             };
117             return toPyObject(self.runtime().runAsync(
118                 arg_ivalues, kwarg_ivalues, task_launcher));
119           });
120   m.def(
121        "_jit_to_static_module",
122        [](const std::shared_ptr<torch::jit::Graph>& g) {
123          return StaticModule(g);
124        })
125       .def(
126           "_jit_to_static_module",
127           [](const torch::jit::Module& module) { return StaticModule(module); })
128       .def(
129           "_fuse_to_static_module",
130           [](torch::jit::Module& module, size_t min_size) {
131             module.eval();
132             module = freeze_module(module);
133 
134             Method method = module.get_method("forward");
135             auto graph = method.graph();
136             fuseStaticSubgraphs(graph, min_size);
137           },
138           py::arg("module"),
139           py::arg("min_size") = DEFAULT_FUSION_SIZE)
140       .def(
141           "_fuse_to_static_module",
142           [](std::shared_ptr<torch::jit::Graph> g, size_t min_size) {
143             fuseStaticSubgraphs(std::move(g), min_size);
144           },
145           py::arg("graph"),
146           py::arg("min_size") = DEFAULT_FUSION_SIZE);
147 }
148 
149 } // namespace torch::jit
150