1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/tools/benchmark/experimental/c/benchmark_c_api.h"
17
18 #include <utility>
19
20 #include "tensorflow/core/util/stats_calculator.h"
21 #include "tensorflow/lite/tools/benchmark/benchmark_tflite_model.h"
22
23 extern "C" {
24
25 // -----------------------------------------------------------------------------
26 // C APIs corresponding to tflite::benchmark::BenchmarkResults type.
27 // -----------------------------------------------------------------------------
// Opaque C handle around C++ benchmark results.
//
// `results` is a borrowed, non-owning pointer. In this file, instances are
// only created on the stack/heap for the duration of an OnBenchmarkEnd
// callback, so C callers must not keep the handle after the callback returns.
struct TfLiteBenchmarkResults {
  const tflite::benchmark::BenchmarkResults* results;
};
31
32 // Converts the given int64_t stat into a TfLiteBenchmarkInt64Stat struct.
ConvertStat(const tensorflow::Stat<int64_t> & stat)33 TfLiteBenchmarkInt64Stat ConvertStat(const tensorflow::Stat<int64_t>& stat) {
34 return {
35 stat.empty(), stat.first(), stat.newest(), stat.max(),
36 stat.min(), stat.count(), stat.sum(), stat.squared_sum(),
37 stat.all_same(), stat.avg(), stat.std_deviation(),
38 };
39 }
40
TfLiteBenchmarkResultsGetInferenceTimeMicroseconds(const TfLiteBenchmarkResults * results)41 TfLiteBenchmarkInt64Stat TfLiteBenchmarkResultsGetInferenceTimeMicroseconds(
42 const TfLiteBenchmarkResults* results) {
43 return ConvertStat(results->results->inference_time_us());
44 }
45
TfLiteBenchmarkResultsGetWarmupTimeMicroseconds(const TfLiteBenchmarkResults * results)46 TfLiteBenchmarkInt64Stat TfLiteBenchmarkResultsGetWarmupTimeMicroseconds(
47 const TfLiteBenchmarkResults* results) {
48 return ConvertStat(results->results->warmup_time_us());
49 }
50
TfLiteBenchmarkResultsGetStartupLatencyMicroseconds(const TfLiteBenchmarkResults * results)51 int64_t TfLiteBenchmarkResultsGetStartupLatencyMicroseconds(
52 const TfLiteBenchmarkResults* results) {
53 return results->results->startup_latency_us();
54 }
55
TfLiteBenchmarkResultsGetInputBytes(const TfLiteBenchmarkResults * results)56 uint64_t TfLiteBenchmarkResultsGetInputBytes(
57 const TfLiteBenchmarkResults* results) {
58 return results->results->input_bytes();
59 }
60
TfLiteBenchmarkResultsGetThroughputMbPerSecond(const TfLiteBenchmarkResults * results)61 double TfLiteBenchmarkResultsGetThroughputMbPerSecond(
62 const TfLiteBenchmarkResults* results) {
63 return results->results->throughput_MB_per_second();
64 }
65
66 // -----------------------------------------------------------------------------
67 // C APIs corresponding to tflite::benchmark::BenchmarkListener type.
68 // -----------------------------------------------------------------------------
69 class BenchmarkListenerAdapter : public tflite::benchmark::BenchmarkListener {
70 public:
OnBenchmarkStart(const tflite::benchmark::BenchmarkParams & params)71 void OnBenchmarkStart(
72 const tflite::benchmark::BenchmarkParams& params) override {
73 if (on_benchmark_start_fn_ != nullptr) {
74 on_benchmark_start_fn_(user_data_);
75 }
76 }
77
OnSingleRunStart(tflite::benchmark::RunType runType)78 void OnSingleRunStart(tflite::benchmark::RunType runType) override {
79 if (on_single_run_start_fn_ != nullptr) {
80 on_single_run_start_fn_(user_data_, runType == tflite::benchmark::WARMUP
81 ? TfLiteBenchmarkWarmup
82 : TfLiteBenchmarkRegular);
83 }
84 }
85
OnSingleRunEnd()86 void OnSingleRunEnd() override {
87 if (on_single_run_end_fn_ != nullptr) {
88 on_single_run_end_fn_(user_data_);
89 }
90 }
91
OnBenchmarkEnd(const tflite::benchmark::BenchmarkResults & results)92 void OnBenchmarkEnd(
93 const tflite::benchmark::BenchmarkResults& results) override {
94 if (on_benchmark_end_fn_ != nullptr) {
95 TfLiteBenchmarkResults* wrapper = new TfLiteBenchmarkResults{&results};
96 on_benchmark_end_fn_(user_data_, wrapper);
97 delete wrapper;
98 }
99 }
100
101 // Keep the user_data pointer provided when setting the callbacks.
102 void* user_data_;
103
104 // Function pointers set by the TfLiteBenchmarkListenerSetCallbacks call.
105 // Only non-null callbacks will be actually called.
106 void (*on_benchmark_start_fn_)(void* user_data);
107 void (*on_single_run_start_fn_)(void* user_data,
108 TfLiteBenchmarkRunType runType);
109 void (*on_single_run_end_fn_)(void* user_data);
110 void (*on_benchmark_end_fn_)(void* user_data,
111 TfLiteBenchmarkResults* results);
112 };
113
// Opaque C handle for a benchmark listener. It owns the C++ adapter that
// forwards BenchmarkListener events to the C callbacks registered through
// TfLiteBenchmarkListenerSetCallbacks(); the adapter is destroyed together
// with this handle.
struct TfLiteBenchmarkListener {
  std::unique_ptr<BenchmarkListenerAdapter> adapter;
};
117
TfLiteBenchmarkListenerCreate()118 TfLiteBenchmarkListener* TfLiteBenchmarkListenerCreate() {
119 std::unique_ptr<BenchmarkListenerAdapter> adapter(
120 new BenchmarkListenerAdapter());
121 return new TfLiteBenchmarkListener{std::move(adapter)};
122 }
123
// Destroys a listener created by TfLiteBenchmarkListenerCreate(), together
// with the adapter it owns. Passing nullptr is a no-op.
// NOTE(review): if the listener was added to a benchmark model, the model
// presumably still holds the raw adapter pointer — delete it only after the
// model is done running; confirm against BenchmarkModel::AddListener.
void TfLiteBenchmarkListenerDelete(TfLiteBenchmarkListener* listener) {
  delete listener;
}
127
TfLiteBenchmarkListenerSetCallbacks(TfLiteBenchmarkListener * listener,void * user_data,void (* on_benchmark_start_fn)(void * user_data),void (* on_single_run_start_fn)(void * user_data,TfLiteBenchmarkRunType runType),void (* on_single_run_end_fn)(void * user_data),void (* on_benchmark_end_fn)(void * user_data,TfLiteBenchmarkResults * results))128 void TfLiteBenchmarkListenerSetCallbacks(
129 TfLiteBenchmarkListener* listener, void* user_data,
130 void (*on_benchmark_start_fn)(void* user_data),
131 void (*on_single_run_start_fn)(void* user_data,
132 TfLiteBenchmarkRunType runType),
133 void (*on_single_run_end_fn)(void* user_data),
134 void (*on_benchmark_end_fn)(void* user_data,
135 TfLiteBenchmarkResults* results)) {
136 listener->adapter->user_data_ = user_data;
137 listener->adapter->on_benchmark_start_fn_ = on_benchmark_start_fn;
138 listener->adapter->on_single_run_start_fn_ = on_single_run_start_fn;
139 listener->adapter->on_single_run_end_fn_ = on_single_run_end_fn;
140 listener->adapter->on_benchmark_end_fn_ = on_benchmark_end_fn;
141 }
142
143 // -----------------------------------------------------------------------------
144 // C APIs corresponding to tflite::benchmark::BenchmarkTfLiteModel type.
145 // -----------------------------------------------------------------------------
// Opaque C handle that owns a tflite::benchmark::BenchmarkTfLiteModel; the
// model is destroyed together with this handle.
struct TfLiteBenchmarkTfLiteModel {
  std::unique_ptr<tflite::benchmark::BenchmarkTfLiteModel> benchmark_model;
};
149
TfLiteBenchmarkTfLiteModelCreate()150 TfLiteBenchmarkTfLiteModel* TfLiteBenchmarkTfLiteModelCreate() {
151 std::unique_ptr<tflite::benchmark::BenchmarkTfLiteModel> benchmark_model(
152 new tflite::benchmark::BenchmarkTfLiteModel());
153 return new TfLiteBenchmarkTfLiteModel{std::move(benchmark_model)};
154 }
155
// Destroys a handle created by TfLiteBenchmarkTfLiteModelCreate(), together
// with the benchmark model it owns. Passing nullptr is a no-op.
void TfLiteBenchmarkTfLiteModelDelete(
    TfLiteBenchmarkTfLiteModel* benchmark_model) {
  delete benchmark_model;
}
160
TfLiteBenchmarkTfLiteModelInit(TfLiteBenchmarkTfLiteModel * benchmark_model)161 TfLiteStatus TfLiteBenchmarkTfLiteModelInit(
162 TfLiteBenchmarkTfLiteModel* benchmark_model) {
163 return benchmark_model->benchmark_model->Init();
164 }
165
TfLiteBenchmarkTfLiteModelRun(TfLiteBenchmarkTfLiteModel * benchmark_model)166 TfLiteStatus TfLiteBenchmarkTfLiteModelRun(
167 TfLiteBenchmarkTfLiteModel* benchmark_model) {
168 return benchmark_model->benchmark_model->Run();
169 }
170
TfLiteBenchmarkTfLiteModelRunWithArgs(TfLiteBenchmarkTfLiteModel * benchmark_model,int argc,char ** argv)171 TfLiteStatus TfLiteBenchmarkTfLiteModelRunWithArgs(
172 TfLiteBenchmarkTfLiteModel* benchmark_model, int argc, char** argv) {
173 return benchmark_model->benchmark_model->Run(argc, argv);
174 }
175
TfLiteBenchmarkTfLiteModelAddListener(TfLiteBenchmarkTfLiteModel * benchmark_model,const TfLiteBenchmarkListener * listener)176 void TfLiteBenchmarkTfLiteModelAddListener(
177 TfLiteBenchmarkTfLiteModel* benchmark_model,
178 const TfLiteBenchmarkListener* listener) {
179 return benchmark_model->benchmark_model->AddListener(listener->adapter.get());
180 }
181
182 } // extern "C"
183