#include <torch/csrc/jit/runtime/static/init.h>

#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/runtime/static/fusion.h>
#include <torch/csrc/jit/runtime/static/impl.h>

#include <utility>

// This number is a heuristic determined with pytorch/benchmark
static constexpr int DEFAULT_FUSION_SIZE = 4;

namespace torch::jit {

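// Registers the static runtime (StaticModule) Python bindings on the given
// module. In PyTorch this is typically invoked during JIT binding
// initialization with the torch._C extension module.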
void initStaticModuleBindings(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();
  py::class_<StaticModule> static_module(m, "StaticModule");
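  // Per-op profiling results produced by benchmark_individual_ops below;
  // each field of StaticRuntime::IndividualMetrics is exposed read-only.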
  py::class_<StaticRuntime::IndividualMetrics>(
      static_module, "IndividualMetrics")
      .def_readonly("setup_time", &StaticRuntime::IndividualMetrics::setup_time)
      .def_readonly(
          "memory_alloc_time",
          &StaticRuntime::IndividualMetrics::memory_alloc_time)
      .def_readonly(
          "memory_dealloc_time",
          &StaticRuntime::IndividualMetrics::memory_dealloc_time)
      .def_readonly(
          "output_dealloc_time",
          &StaticRuntime::IndividualMetrics::output_dealloc_time)
      .def_readonly(
          "first_iter_time", &StaticRuntime::IndividualMetrics::first_iter_time)
      .def_readonly("total_time", &StaticRuntime::IndividualMetrics::total_time)
      .def_readonly(
          "out_nodes_count", &StaticRuntime::IndividualMetrics::out_nodes_count)
      .def_readonly(
          "total_nodes_count",
          &StaticRuntime::IndividualMetrics::total_nodes_count)
      .def_readonly(
          "time_per_node", &StaticRuntime::IndividualMetrics::time_per_node)
      .def_readonly(
          "time_per_node_type",
          &StaticRuntime::IndividualMetrics::time_per_node_type)
      .def_readonly(
          "percent_per_node_type",
          &StaticRuntime::IndividualMetrics::percent_per_node_type)
      .def_readonly(
          "instances_per_node_type",
          &StaticRuntime::IndividualMetrics::instances_per_node_type)
      .def_readonly("out_nodes", &StaticRuntime::IndividualMetrics::out_nodes);
  static_module
      .def(
          "__call__",
          [](StaticModule& self,
             const py::args& args,
             const py::kwargs& kwargs) {
            std::vector<c10::IValue> arg_ivalues;
            arg_ivalues.reserve(args.size());
            std::unordered_map<std::string, c10::IValue> kwarg_ivalues;
            kwarg_ivalues.reserve(kwargs.size());
            for (const auto& arg : args) {
              auto ivalue = torch::jit::toIValue(arg, c10::AnyType::get());
              arg_ivalues.push_back(std::move(ivalue));
            }
            for (const auto& kv : kwargs) {
              kwarg_ivalues[py::cast<std::string>(kv.first)] =
                  torch::jit::toIValue(kv.second, c10::AnyType::get());
            }
            c10::IValue ret = self(arg_ivalues, kwarg_ivalues);
            return toPyObject(std::move(ret));
          })
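      // benchmark runs the whole graph for warmup_runs + main_runs
      // iterations and reports aggregate timings through the StaticRuntime
      // benchmark harness.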
      .def(
          "benchmark",
          [](StaticModule& self,
             const std::vector<at::Tensor>& args,
             const std::unordered_map<std::string, at::Tensor>& kwargs,
             const int warmup_runs,
             const int main_runs) {
            std::vector<c10::IValue> arg_ivalues{args.begin(), args.end()};
            std::unordered_map<std::string, c10::IValue> kwarg_ivalues{
                kwargs.begin(), kwargs.end()};
            self.runtime().benchmark(
                {arg_ivalues}, {kwarg_ivalues}, warmup_runs, main_runs);
          })
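      // benchmark_individual_ops returns the IndividualMetrics bound above.
      // Hypothetical Python usage (assuming the bindings are registered on
      // torch._C):
      //
      //   sm = torch._C._jit_to_static_module(scripted_fn.graph)
      //   metrics = sm.benchmark_individual_ops([inp], {}, 2, 10)
      //   print(metrics.total_time, metrics.time_per_node_type)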
      .def(
          "benchmark_individual_ops",
          [](StaticModule& self,
             const std::vector<at::Tensor>& args,
             const std::unordered_map<std::string, at::Tensor>& kwargs,
             const int warmup_runs,
             const int main_runs) {
            std::vector<c10::IValue> arg_ivalues{args.begin(), args.end()};
            std::unordered_map<std::string, c10::IValue> kwarg_ivalues{
                kwargs.begin(), kwargs.end()};
            return self.runtime().benchmark_individual_ops(
                {arg_ivalues}, {kwarg_ivalues}, warmup_runs, main_runs);
          })
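      // runAsync executes the graph with a custom task launcher that
      // schedules inter-op work onto the ATen thread pool via at::launch,
      // so asynchronous ops can run concurrently.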
      .def(
          "runAsync",
          [](StaticModule& self,
             const py::tuple& args,
             const py::dict& kwargs) {
            std::vector<c10::IValue> arg_ivalues;
            arg_ivalues.reserve(args.size());
            for (const auto& elem : args) {
              arg_ivalues.push_back(
                  torch::jit::toIValue(elem, c10::AnyType::get()));
            }
            std::unordered_map<std::string, c10::IValue> kwarg_ivalues;
            kwarg_ivalues.reserve(kwargs.size());
            for (const auto& kv : kwargs) {
              kwarg_ivalues[py::cast<std::string>(kv.first)] =
                  torch::jit::toIValue(kv.second, c10::AnyType::get());
            }
            // custom executor for async op execution
            auto task_launcher = [](const std::function<void()>& f) {
              at::launch(f);
            };
            return toPyObject(self.runtime().runAsync(
                arg_ivalues, kwarg_ivalues, task_launcher));
          });
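  // Module-level helpers: _jit_to_static_module wraps a TorchScript graph or
  // module in a StaticModule; _fuse_to_static_module freezes a module (or
  // takes a raw graph) and outlines static subgraphs with at least min_size
  // nodes into fusion groups. Hypothetical Python usage:
  //
  //   torch._C._fuse_to_static_module(scripted_mod._c, min_size=4)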
  m.def(
       "_jit_to_static_module",
       [](const std::shared_ptr<torch::jit::Graph>& g) {
         return StaticModule(g);
       })
      .def(
          "_jit_to_static_module",
          [](const torch::jit::Module& module) { return StaticModule(module); })
      .def(
          "_fuse_to_static_module",
          [](torch::jit::Module& module, size_t min_size) {
            module.eval();
            module = freeze_module(module);

            Method method = module.get_method("forward");
            auto graph = method.graph();
            fuseStaticSubgraphs(graph, min_size);
          },
          py::arg("module"),
          py::arg("min_size") = DEFAULT_FUSION_SIZE)
      .def(
          "_fuse_to_static_module",
          [](std::shared_ptr<torch::jit::Graph> g, size_t min_size) {
            fuseStaticSubgraphs(std::move(g), min_size);
          },
          py::arg("graph"),
          py::arg("min_size") = DEFAULT_FUSION_SIZE);
}

} // namespace torch::jit