xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/tools/benchmark/benchmark_model.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_
17 #define TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_
18 
#include <cmath>
#include <cstdint>
#include <limits>
#include <memory>
#include <ostream>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/util/stats_calculator.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/profiling/memory_info.h"
#include "tensorflow/lite/profiling/memory_usage_monitor.h"
#include "tensorflow/lite/tools/benchmark/benchmark_params.h"
#include "tensorflow/lite/tools/command_line_flags.h"
34 
35 namespace tflite {
36 namespace benchmark {
37 
// Distinguishes warmup iterations (whose timings are reported separately)
// from the regular, measured benchmark iterations.
enum RunType {
  WARMUP,
  REGULAR,
};
42 
43 class BenchmarkResults {
44  public:
BenchmarkResults()45   BenchmarkResults() {}
BenchmarkResults(double model_size_mb,int64_t startup_latency_us,uint64_t input_bytes,tensorflow::Stat<int64_t> warmup_time_us,tensorflow::Stat<int64_t> inference_time_us,const profiling::memory::MemoryUsage & init_mem_usage,const profiling::memory::MemoryUsage & overall_mem_usage,float peak_mem_mb)46   BenchmarkResults(double model_size_mb, int64_t startup_latency_us,
47                    uint64_t input_bytes,
48                    tensorflow::Stat<int64_t> warmup_time_us,
49                    tensorflow::Stat<int64_t> inference_time_us,
50                    const profiling::memory::MemoryUsage& init_mem_usage,
51                    const profiling::memory::MemoryUsage& overall_mem_usage,
52                    float peak_mem_mb)
53       : model_size_mb_(model_size_mb),
54         startup_latency_us_(startup_latency_us),
55         input_bytes_(input_bytes),
56         warmup_time_us_(warmup_time_us),
57         inference_time_us_(inference_time_us),
58         init_mem_usage_(init_mem_usage),
59         overall_mem_usage_(overall_mem_usage),
60         peak_mem_mb_(peak_mem_mb) {}
61 
model_size_mb()62   const double model_size_mb() const { return model_size_mb_; }
inference_time_us()63   tensorflow::Stat<int64_t> inference_time_us() const {
64     return inference_time_us_;
65   }
warmup_time_us()66   tensorflow::Stat<int64_t> warmup_time_us() const { return warmup_time_us_; }
startup_latency_us()67   int64_t startup_latency_us() const { return startup_latency_us_; }
input_bytes()68   uint64_t input_bytes() const { return input_bytes_; }
throughput_MB_per_second()69   double throughput_MB_per_second() const {
70     double bytes_per_sec = (input_bytes_ * inference_time_us_.count() * 1e6) /
71                            inference_time_us_.sum();
72     return bytes_per_sec / (1024.0 * 1024.0);
73   }
74 
init_mem_usage()75   const profiling::memory::MemoryUsage& init_mem_usage() const {
76     return init_mem_usage_;
77   }
overall_mem_usage()78   const profiling::memory::MemoryUsage& overall_mem_usage() const {
79     return overall_mem_usage_;
80   }
peak_mem_mb()81   float peak_mem_mb() const { return peak_mem_mb_; }
82 
83  private:
84   double model_size_mb_ = 0.0;
85   int64_t startup_latency_us_ = 0;
86   uint64_t input_bytes_ = 0;
87   tensorflow::Stat<int64_t> warmup_time_us_;
88   tensorflow::Stat<int64_t> inference_time_us_;
89   profiling::memory::MemoryUsage init_mem_usage_;
90   profiling::memory::MemoryUsage overall_mem_usage_;
91   // An invalid value could happen when we don't monitor memory footprint for
92   // the inference, or the memory usage info isn't available on the benchmarking
93   // platform.
94   float peak_mem_mb_ =
95       profiling::memory::MemoryUsageMonitor::kInvalidMemUsageMB;
96 };
97 
98 class BenchmarkListener {
99  public:
100   // Called before the (outer) inference loop begins.
101   // Note that this is called *after* the interpreter has been initialized, but
102   // *before* any warmup runs have been executed.
OnBenchmarkStart(const BenchmarkParams & params)103   virtual void OnBenchmarkStart(const BenchmarkParams& params) {}
104   // Called before a single (inner) inference call starts.
OnSingleRunStart(RunType runType)105   virtual void OnSingleRunStart(RunType runType) {}
106   // Called before a single (inner) inference call ends.
OnSingleRunEnd()107   virtual void OnSingleRunEnd() {}
108   // Called after the (outer) inference loop begins.
OnBenchmarkEnd(const BenchmarkResults & results)109   virtual void OnBenchmarkEnd(const BenchmarkResults& results) {}
~BenchmarkListener()110   virtual ~BenchmarkListener() {}
111 };
112 
113 // A listener that forwards its method calls to a collection of listeners.
114 class BenchmarkListeners : public BenchmarkListener {
115  public:
116   // Added a listener to the listener collection.
117   // |listener| is not owned by the instance of |BenchmarkListeners|.
118   // |listener| should not be null and should outlast the instance of
119   // |BenchmarkListeners|.
AddListener(BenchmarkListener * listener)120   void AddListener(BenchmarkListener* listener) {
121     listeners_.push_back(listener);
122   }
123 
124   // Remove all listeners after [index] including the one at 'index'.
RemoveListeners(int index)125   void RemoveListeners(int index) {
126     if (index >= NumListeners()) return;
127     listeners_.resize(index);
128   }
129 
NumListeners()130   int NumListeners() const { return listeners_.size(); }
131 
OnBenchmarkStart(const BenchmarkParams & params)132   void OnBenchmarkStart(const BenchmarkParams& params) override {
133     for (auto listener : listeners_) {
134       listener->OnBenchmarkStart(params);
135     }
136   }
137 
OnSingleRunStart(RunType runType)138   void OnSingleRunStart(RunType runType) override {
139     for (auto listener : listeners_) {
140       listener->OnSingleRunStart(runType);
141     }
142   }
143 
OnSingleRunEnd()144   void OnSingleRunEnd() override {
145     for (auto listener : listeners_) {
146       listener->OnSingleRunEnd();
147     }
148   }
149 
OnBenchmarkEnd(const BenchmarkResults & results)150   void OnBenchmarkEnd(const BenchmarkResults& results) override {
151     for (auto listener : listeners_) {
152       listener->OnBenchmarkEnd(results);
153     }
154   }
155 
~BenchmarkListeners()156   ~BenchmarkListeners() override {}
157 
158  private:
159   // Use vector so listeners are invoked in the order they are added.
160   std::vector<BenchmarkListener*> listeners_;
161 };
162 
163 // Benchmark listener that just logs the results of benchmark run.
164 class BenchmarkLoggingListener : public BenchmarkListener {
165  public:
166   void OnBenchmarkEnd(const BenchmarkResults& results) override;
167 };
168 
169 template <typename T>
CreateFlag(const char * name,BenchmarkParams * params,const std::string & usage)170 Flag CreateFlag(const char* name, BenchmarkParams* params,
171                 const std::string& usage) {
172   return Flag(
173       name,
174       [params, name](const T& val, int argv_position) {
175         params->Set<T>(name, val, argv_position);
176       },
177       params->Get<T>(name), usage, Flag::kOptional);
178 }
179 
180 // Benchmarks a model.
181 //
182 // Subclasses need to implement initialization and running of the model.
183 // The results can be collected by adding BenchmarkListener(s).
184 class BenchmarkModel {
185  public:
186   static BenchmarkParams DefaultParams();
187   BenchmarkModel();
BenchmarkModel(BenchmarkParams params)188   explicit BenchmarkModel(BenchmarkParams params)
189       : params_(std::move(params)) {}
~BenchmarkModel()190   virtual ~BenchmarkModel() {}
191   virtual TfLiteStatus Init() = 0;
192   virtual TfLiteStatus Run(int argc, char** argv);
193   virtual TfLiteStatus Run();
AddListener(BenchmarkListener * listener)194   void AddListener(BenchmarkListener* listener) {
195     listeners_.AddListener(listener);
196   }
197   // Remove all listeners after [index] including the one at 'index'.
RemoveListeners(int index)198   void RemoveListeners(int index) { listeners_.RemoveListeners(index); }
NumListeners()199   int NumListeners() const { return listeners_.NumListeners(); }
200 
mutable_params()201   BenchmarkParams* mutable_params() { return &params_; }
202 
203   // Unparsable flags will remain in 'argv' in the original order and 'argc'
204   // will be updated accordingly.
205   TfLiteStatus ParseFlags(int* argc, char** argv);
206 
207  protected:
208   virtual void LogParams();
209   virtual TfLiteStatus ValidateParams();
210 
ParseFlags(int argc,char ** argv)211   TfLiteStatus ParseFlags(int argc, char** argv) {
212     return ParseFlags(&argc, argv);
213   }
214   virtual std::vector<Flag> GetFlags();
215 
216   // Get the model file size if it's available.
MayGetModelFileSize()217   virtual int64_t MayGetModelFileSize() { return -1; }
218   virtual uint64_t ComputeInputBytes() = 0;
219   virtual tensorflow::Stat<int64_t> Run(int min_num_times, float min_secs,
220                                         float max_secs, RunType run_type,
221                                         TfLiteStatus* invoke_status);
222   // Prepares input data for benchmark. This can be used to initialize input
223   // data that has non-trivial cost.
224   virtual TfLiteStatus PrepareInputData();
225 
226   virtual TfLiteStatus ResetInputsAndOutputs();
227   virtual TfLiteStatus RunImpl() = 0;
228 
229   // Create a MemoryUsageMonitor to report peak memory footprint if specified.
230   virtual std::unique_ptr<profiling::memory::MemoryUsageMonitor>
231   MayCreateMemoryUsageMonitor() const;
232 
233   BenchmarkParams params_;
234   BenchmarkListeners listeners_;
235 };
236 
237 }  // namespace benchmark
238 }  // namespace tflite
239 
240 #endif  // TENSORFLOW_LITE_TOOLS_BENCHMARK_BENCHMARK_MODEL_H_
241