#pragma once

#include <ATen/core/ivalue.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/utils/pybind.h>

#include <torch/csrc/jit/python/pybind_utils.h>

#include <iosfwd>
#include <memory>
#include <string>
#include <vector>

namespace py = pybind11;

namespace torch::throughput_benchmark {

/**
 * This struct is used to provide the results of a benchmark to the caller.
 * In the future, any additional statistics should be added here.
 */
struct BenchmarkExecutionStats {
  float latency_avg_ms{-1};
  int64_t num_iters{-1};
};

std::ostream& operator<<(
    std::ostream& os,
    const BenchmarkExecutionStats& value);
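
// A minimal sketch of consuming the stats (assumes a hypothetical `stats`
// value returned by ThroughputBenchmark::benchmark, declared further below):
//
//   std::cout << stats;  // uses the operator<< declared above
//   double avg_ms = stats.latency_avg_ms;
//   int64_t iters = stats.num_iters;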

/**
 * Use this struct to configure a throughput benchmark run.
 * It should include parameters related to threading, batching, number of
 * iterations, warm-up, etc. More configs can be added as needed.
 * The general rule here is that only things that C++ must(!) be aware of
 * should live here. If we can keep other parts in Python, we should keep them
 * there. This is typical for things that are not perf critical and don't
 * affect the execution statistics the benchmark returns.
 */
struct BenchmarkConfig {
 public:
  // Calling threads are the threads that call into the module in parallel.
  int num_calling_threads{1};
  // Worker threads are not supported yet. This is just a placeholder showing
  // that we plan to support some form of multi-threaded forward calls. We may
  // change this setting in the future to support different intra- and
  // inter-op parallelism, which is not available in PyTorch yet.
  int num_worker_threads{1};
  // Warmup iterations are used to make sure we run the module a few times
  // before actually measuring anything. This way we avoid cold caches and
  // other similar problems.
  int num_warmup_iters{1};
  // Number of iterations the benchmark should run for. This number is
  // separate from the warmup iterations.
  int64_t num_iters{100};
  // If set, the autograd profiler will be enabled, i.e. the following guard
  // will be created before the main benchmark loop (but after the warmup):
  //   RecordProfile guard(profiler_output_path);
  std::string profiler_output_path{""};
};
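
// A minimal sketch of filling in a config (all values, including the profiler
// output path, are illustrative assumptions rather than recommended defaults):
//
//   BenchmarkConfig config;
//   config.num_calling_threads = 4;
//   config.num_warmup_iters = 10;
//   config.num_iters = 1000;
//   config.profiler_output_path = "bench_profile.trace";  // enables profiler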

namespace detail {

/**
 * A helper class to abstract out the different kinds of models we benchmark.
 */
template <class Input, class Output, class Model>
class BenchmarkHelper {
 public:
  BenchmarkHelper();
  explicit BenchmarkHelper(Model model)
      : model_(std::move(model)), initialized_(true) {}

  // This overload is used inside the benchmark() method.
  // Note that there is no result. This way we don't have to call it under the
  // GIL even when running in nn.Module mode; otherwise the destructor of the
  // result would race with Python.
  void runOnce(Input&&) const;
  // This overload is used when calling from Python directly.
  Output runOnce(const py::args&, const py::kwargs&) const;
  // Aggregate inputs in the format Model expects in order to avoid further
  // conversions at benchmark time.
  void addInput(py::args&&, py::kwargs&&);
  void addInput(Input&&);
  BenchmarkExecutionStats benchmark(const BenchmarkConfig& config) const;

  bool initialized() const {
    return initialized_;
  }

  // The destructor doesn't require the GIL because it is going to be executed
  // on the Python thread.
  std::vector<Input> inputs_;
  Model model_;
  bool initialized_{false};
};

struct C10_HIDDEN ModuleInput {
  ModuleInput(ModuleInput&& other) = default;

  ModuleInput(const ModuleInput&) = delete;
  ModuleInput& operator=(ModuleInput& other) = delete;
  ModuleInput& operator=(ModuleInput&& other) = delete;

  ModuleInput(py::args&& args, py::kwargs&& kwargs)
      : args(std::move(args)), kwargs(std::move(kwargs)) {}

  py::args args;
  py::kwargs kwargs;
};
typedef py::object ModuleOutput;
typedef std::vector<at::IValue> ScriptModuleInput;
typedef at::IValue ScriptModuleOutput;

template <class Input>
Input cloneInput(const Input& input);

typedef BenchmarkHelper<ScriptModuleInput, at::IValue, jit::Module>
    ScriptModuleBenchmark;
template <>
inline BenchmarkHelper<ScriptModuleInput, at::IValue, jit::Module>::
    BenchmarkHelper()
    : model_("Module", std::make_shared<jit::CompilationUnit>()),
      initialized_(false) {}
typedef BenchmarkHelper<ModuleInput, py::object, py::object> ModuleBenchmark;
template <>
inline BenchmarkHelper<ModuleInput, py::object, py::object>::BenchmarkHelper()
    : initialized_(false) {}

template <>
void ScriptModuleBenchmark::runOnce(ScriptModuleInput&& input) const;

template <>
ScriptModuleOutput ScriptModuleBenchmark::runOnce(
    const py::args& args,
    const py::kwargs& kwargs) const;

template <>
void ModuleBenchmark::runOnce(ModuleInput&& input) const;

template <>
ModuleOutput ModuleBenchmark::runOnce(
    const py::args& args,
    const py::kwargs& kwargs) const;

template <>
void ScriptModuleBenchmark::addInput(py::args&& args, py::kwargs&& kwargs);
template <>
void ScriptModuleBenchmark::addInput(ScriptModuleInput&& input);

template <>
void ModuleBenchmark::addInput(py::args&& args, py::kwargs&& kwargs);
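
// A minimal internal usage sketch of the ScriptModule specialization (assumes
// `module` is a jit::Module and `stack` is a std::vector<at::IValue> holding
// one example input for its forward()):
//
//   ScriptModuleBenchmark helper(module);
//   helper.addInput(std::move(stack));
//   BenchmarkConfig config;
//   BenchmarkExecutionStats stats = helper.benchmark(config);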

} // namespace detail

/**
 * This class is a small C++ component responsible for executing a PyTorch
 * module under an inference-server-like load. It can emulate multiple calling
 * threads to a single provided module. In the future we plan to enhance this
 * component to support inter- and intra-op parallelism as well as multiple
 * models running in a single process.
 *
 * For the currently available configurations, refer to the BenchmarkConfig
 * documentation.
 *
 * The class supports working with either an nn.Module or a ScriptModule.
 * Under the hood it simply dispatches to the corresponding specialization of
 * class BenchmarkHelper<Input, Output, Model>.
 */
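//
// A minimal usage sketch (a sketch only; assumes `module` is a loaded
// jit::Module and `args` / `kwargs` are a valid py::args / py::kwargs pair
// matching the module's forward() signature):
//
//   ThroughputBenchmark bench(module);
//   bench.addInput(args, kwargs);
//   BenchmarkConfig config;
//   config.num_calling_threads = 4;
//   BenchmarkExecutionStats stats = bench.benchmark(config);
//   std::cout << stats;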
class C10_HIDDEN ThroughputBenchmark {
 public:
  explicit ThroughputBenchmark(const jit::Module& module);
  explicit ThroughputBenchmark(py::object module);

  // Add one more input example. This input example should be in the exact
  // format the module under test expects. It is the responsibility of the
  // module to perform any such format checks; the benchmark doesn't perform
  // any validation of its own.
  void addInput(py::args args, py::kwargs kwargs);

  // Equivalent to just running the model directly on the given input.
  py::object runOnce(const py::args& args, const py::kwargs& kwargs);

  // The main method of the class: it performs a multi-threaded benchmark and
  // returns a BenchmarkExecutionStats object with useful statistics about
  // runtime execution. We can enhance this class in the future to provide
  // more information to the user.
  BenchmarkExecutionStats benchmark(const BenchmarkConfig& config) const;

 private:
  detail::ScriptModuleBenchmark script_module_;
  detail::ModuleBenchmark module_;
};
} // namespace torch::throughput_benchmark

#include <torch/csrc/utils/throughput_benchmark-inl.h>