/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_
#define TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_

#include <cmath>
#include <cstdint>
#include <limits>
#include <memory>
#include <ostream>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/util/stats_calculator.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/profiling/memory_info.h"
#include "tensorflow/lite/profiling/memory_usage_monitor.h"
#include "tensorflow/lite/tools/benchmark/benchmark_params.h"
#include "tensorflow/lite/tools/command_line_flags.h"

namespace tflite {
namespace benchmark {

enum RunType {
  WARMUP,
  REGULAR,
};

class BenchmarkResults {
 public:
  BenchmarkResults() {}
  BenchmarkResults(double model_size_mb, int64_t startup_latency_us,
                   uint64_t input_bytes,
                   tensorflow::Stat<int64_t> warmup_time_us,
                   tensorflow::Stat<int64_t> inference_time_us,
                   const profiling::memory::MemoryUsage& init_mem_usage,
                   const profiling::memory::MemoryUsage& overall_mem_usage,
                   float peak_mem_mb)
      : model_size_mb_(model_size_mb),
        startup_latency_us_(startup_latency_us),
        input_bytes_(input_bytes),
        warmup_time_us_(warmup_time_us),
        inference_time_us_(inference_time_us),
        init_mem_usage_(init_mem_usage),
        overall_mem_usage_(overall_mem_usage),
        peak_mem_mb_(peak_mem_mb) {}

  double model_size_mb() const { return model_size_mb_; }
  tensorflow::Stat<int64_t> inference_time_us() const {
    return inference_time_us_;
  }
  tensorflow::Stat<int64_t> warmup_time_us() const { return warmup_time_us_; }
  int64_t startup_latency_us() const { return startup_latency_us_; }
  uint64_t input_bytes() const { return input_bytes_; }
  double throughput_MB_per_second() const {
    double bytes_per_sec = (input_bytes_ * inference_time_us_.count() * 1e6) /
                           inference_time_us_.sum();
    return bytes_per_sec / (1024.0 * 1024.0);
  }

  const profiling::memory::MemoryUsage& init_mem_usage() const {
    return init_mem_usage_;
  }
  const profiling::memory::MemoryUsage& overall_mem_usage() const {
    return overall_mem_usage_;
  }
  float peak_mem_mb() const { return peak_mem_mb_; }

 private:
  double model_size_mb_ = 0.0;
  int64_t startup_latency_us_ = 0;
  uint64_t input_bytes_ = 0;
  tensorflow::Stat<int64_t> warmup_time_us_;
  tensorflow::Stat<int64_t> inference_time_us_;
  profiling::memory::MemoryUsage init_mem_usage_;
  profiling::memory::MemoryUsage overall_mem_usage_;
  // An invalid value can occur when memory footprint isn't monitored for the
  // inference, or when the memory usage info isn't available on the
  // benchmarking platform.
  float peak_mem_mb_ =
      profiling::memory::MemoryUsageMonitor::kInvalidMemUsageMB;
};
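
// A worked example of the throughput computation above (the numbers are
// hypothetical): with input_bytes_ = 1048576 (1 MiB per inference), 50
// measured inferences, and a total inference time of 5000000 us:
//   bytes_per_sec = (1048576 * 50 * 1e6) / 5000000 = 10485760
//   => throughput_MB_per_second() = 10485760 / (1024 * 1024) = 10.0 MB/s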

class BenchmarkListener {
 public:
  // Called before the (outer) inference loop begins.
  // Note that this is called *after* the interpreter has been initialized, but
  // *before* any warmup runs have been executed.
  virtual void OnBenchmarkStart(const BenchmarkParams& params) {}
  // Called before a single (inner) inference call starts.
  virtual void OnSingleRunStart(RunType runType) {}
  // Called after a single (inner) inference call ends.
  virtual void OnSingleRunEnd() {}
  // Called after the (outer) inference loop ends.
  virtual void OnBenchmarkEnd(const BenchmarkResults& results) {}
  virtual ~BenchmarkListener() {}
};
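
// A minimal sketch of a custom listener (the class name and the field it
// reports are illustrative, not part of the library):
//
//   class StdoutReportingListener : public BenchmarkListener {
//    public:
//     void OnBenchmarkEnd(const BenchmarkResults& results) override {
//       std::cout << "avg inference (us): "
//                 << results.inference_time_us().avg() << "\n";
//     }
//   };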

// A listener that forwards its method calls to a collection of listeners.
class BenchmarkListeners : public BenchmarkListener {
 public:
  // Adds a listener to the listener collection.
  // |listener| is not owned by the instance of |BenchmarkListeners|.
  // |listener| should not be null and should outlast the instance of
  // |BenchmarkListeners|.
  void AddListener(BenchmarkListener* listener) {
    listeners_.push_back(listener);
  }

  // Removes all listeners at or after |index|, including the one at |index|.
  void RemoveListeners(int index) {
    if (index >= NumListeners()) return;
    listeners_.resize(index);
  }

  int NumListeners() const { return listeners_.size(); }

  void OnBenchmarkStart(const BenchmarkParams& params) override {
    for (auto listener : listeners_) {
      listener->OnBenchmarkStart(params);
    }
  }

  void OnSingleRunStart(RunType runType) override {
    for (auto listener : listeners_) {
      listener->OnSingleRunStart(runType);
    }
  }

  void OnSingleRunEnd() override {
    for (auto listener : listeners_) {
      listener->OnSingleRunEnd();
    }
  }

  void OnBenchmarkEnd(const BenchmarkResults& results) override {
    for (auto listener : listeners_) {
      listener->OnBenchmarkEnd(results);
    }
  }

  ~BenchmarkListeners() override {}

 private:
  // Use a vector so listeners are invoked in the order they are added.
  std::vector<BenchmarkListener*> listeners_;
};

// A benchmark listener that simply logs the results of the benchmark run.
class BenchmarkLoggingListener : public BenchmarkListener {
 public:
  void OnBenchmarkEnd(const BenchmarkResults& results) override;
};

template <typename T>
Flag CreateFlag(const char* name, BenchmarkParams* params,
                const std::string& usage) {
  return Flag(
      name,
      [params, name](const T& val, int argv_position) {
        params->Set<T>(name, val, argv_position);
      },
      params->Get<T>(name), usage, Flag::kOptional);
}
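
// Example usage (a sketch; it assumes "num_runs" already exists as an int32_t
// parameter in |params|, since Get<T> above supplies the flag's default):
//
//   BenchmarkParams params = BenchmarkModel::DefaultParams();
//   std::vector<Flag> flags = {
//       CreateFlag<int32_t>("num_runs", &params, "number of benchmark runs"),
//   };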

// Benchmarks a model.
//
// Subclasses need to implement initialization and running of the model.
// The results can be collected by adding BenchmarkListener(s).
class BenchmarkModel {
 public:
  static BenchmarkParams DefaultParams();
  BenchmarkModel();
  explicit BenchmarkModel(BenchmarkParams params)
      : params_(std::move(params)) {}
  virtual ~BenchmarkModel() {}
  virtual TfLiteStatus Init() = 0;
  virtual TfLiteStatus Run(int argc, char** argv);
  virtual TfLiteStatus Run();
  void AddListener(BenchmarkListener* listener) {
    listeners_.AddListener(listener);
  }
  // Removes all listeners at or after |index|, including the one at |index|.
  void RemoveListeners(int index) { listeners_.RemoveListeners(index); }
  int NumListeners() const { return listeners_.NumListeners(); }

  BenchmarkParams* mutable_params() { return &params_; }

  // Unparsable flags will remain in 'argv' in the original order, and 'argc'
  // will be updated accordingly.
  TfLiteStatus ParseFlags(int* argc, char** argv);

 protected:
  virtual void LogParams();
  virtual TfLiteStatus ValidateParams();

  TfLiteStatus ParseFlags(int argc, char** argv) {
    return ParseFlags(&argc, argv);
  }
  virtual std::vector<Flag> GetFlags();

  // Gets the model file size if it's available; returns -1 otherwise.
  virtual int64_t MayGetModelFileSize() { return -1; }
  virtual uint64_t ComputeInputBytes() = 0;
  virtual tensorflow::Stat<int64_t> Run(int min_num_times, float min_secs,
                                        float max_secs, RunType run_type,
                                        TfLiteStatus* invoke_status);
  // Prepares input data for the benchmark. This can be used to initialize
  // input data that has a non-trivial cost.
  virtual TfLiteStatus PrepareInputData();

  virtual TfLiteStatus ResetInputsAndOutputs();
  virtual TfLiteStatus RunImpl() = 0;

  // Creates a MemoryUsageMonitor to report the peak memory footprint if
  // specified.
  virtual std::unique_ptr<profiling::memory::MemoryUsageMonitor>
  MayCreateMemoryUsageMonitor() const;

  BenchmarkParams params_;
  BenchmarkListeners listeners_;
};
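
// A minimal sketch of a concrete benchmark built on BenchmarkModel (the class
// and member names below are illustrative, not part of the library):
//
//   class MyModelBenchmark : public BenchmarkModel {
//    public:
//     TfLiteStatus Init() override {
//       // Build the interpreter and allocate tensors here.
//       return kTfLiteOk;
//     }
//     uint64_t ComputeInputBytes() override { return input_bytes_; }
//     TfLiteStatus RunImpl() override { return interpreter_->Invoke(); }
//   };
//
// A typical driver adds a listener and runs with command-line flags:
//
//   MyModelBenchmark benchmark;
//   BenchmarkLoggingListener logger;
//   benchmark.AddListener(&logger);
//   benchmark.Run(argc, argv);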

}  // namespace benchmark
}  // namespace tflite

#endif  // TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_