xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/tools/benchmark/experimental/c/benchmark_c_api.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/tools/benchmark/experimental/c/benchmark_c_api.h"
17 
18 #include <utility>
19 
20 #include "tensorflow/core/util/stats_calculator.h"
21 #include "tensorflow/lite/tools/benchmark/benchmark_tflite_model.h"
22 
23 extern "C" {
24 
25 // -----------------------------------------------------------------------------
26 // C APIs corresponding to tflite::benchmark::BenchmarkResults type.
27 // -----------------------------------------------------------------------------
28 struct TfLiteBenchmarkResults {
29   const tflite::benchmark::BenchmarkResults* results;
30 };
31 
32 // Converts the given int64_t stat into a TfLiteBenchmarkInt64Stat struct.
ConvertStat(const tensorflow::Stat<int64_t> & stat)33 TfLiteBenchmarkInt64Stat ConvertStat(const tensorflow::Stat<int64_t>& stat) {
34   return {
35       stat.empty(),    stat.first(), stat.newest(),        stat.max(),
36       stat.min(),      stat.count(), stat.sum(),           stat.squared_sum(),
37       stat.all_same(), stat.avg(),   stat.std_deviation(),
38   };
39 }
40 
TfLiteBenchmarkResultsGetInferenceTimeMicroseconds(const TfLiteBenchmarkResults * results)41 TfLiteBenchmarkInt64Stat TfLiteBenchmarkResultsGetInferenceTimeMicroseconds(
42     const TfLiteBenchmarkResults* results) {
43   return ConvertStat(results->results->inference_time_us());
44 }
45 
TfLiteBenchmarkResultsGetWarmupTimeMicroseconds(const TfLiteBenchmarkResults * results)46 TfLiteBenchmarkInt64Stat TfLiteBenchmarkResultsGetWarmupTimeMicroseconds(
47     const TfLiteBenchmarkResults* results) {
48   return ConvertStat(results->results->warmup_time_us());
49 }
50 
TfLiteBenchmarkResultsGetStartupLatencyMicroseconds(const TfLiteBenchmarkResults * results)51 int64_t TfLiteBenchmarkResultsGetStartupLatencyMicroseconds(
52     const TfLiteBenchmarkResults* results) {
53   return results->results->startup_latency_us();
54 }
55 
TfLiteBenchmarkResultsGetInputBytes(const TfLiteBenchmarkResults * results)56 uint64_t TfLiteBenchmarkResultsGetInputBytes(
57     const TfLiteBenchmarkResults* results) {
58   return results->results->input_bytes();
59 }
60 
TfLiteBenchmarkResultsGetThroughputMbPerSecond(const TfLiteBenchmarkResults * results)61 double TfLiteBenchmarkResultsGetThroughputMbPerSecond(
62     const TfLiteBenchmarkResults* results) {
63   return results->results->throughput_MB_per_second();
64 }
65 
66 // -----------------------------------------------------------------------------
67 // C APIs corresponding to tflite::benchmark::BenchmarkListener type.
68 // -----------------------------------------------------------------------------
69 class BenchmarkListenerAdapter : public tflite::benchmark::BenchmarkListener {
70  public:
OnBenchmarkStart(const tflite::benchmark::BenchmarkParams & params)71   void OnBenchmarkStart(
72       const tflite::benchmark::BenchmarkParams& params) override {
73     if (on_benchmark_start_fn_ != nullptr) {
74       on_benchmark_start_fn_(user_data_);
75     }
76   }
77 
OnSingleRunStart(tflite::benchmark::RunType runType)78   void OnSingleRunStart(tflite::benchmark::RunType runType) override {
79     if (on_single_run_start_fn_ != nullptr) {
80       on_single_run_start_fn_(user_data_, runType == tflite::benchmark::WARMUP
81                                               ? TfLiteBenchmarkWarmup
82                                               : TfLiteBenchmarkRegular);
83     }
84   }
85 
OnSingleRunEnd()86   void OnSingleRunEnd() override {
87     if (on_single_run_end_fn_ != nullptr) {
88       on_single_run_end_fn_(user_data_);
89     }
90   }
91 
OnBenchmarkEnd(const tflite::benchmark::BenchmarkResults & results)92   void OnBenchmarkEnd(
93       const tflite::benchmark::BenchmarkResults& results) override {
94     if (on_benchmark_end_fn_ != nullptr) {
95       TfLiteBenchmarkResults* wrapper = new TfLiteBenchmarkResults{&results};
96       on_benchmark_end_fn_(user_data_, wrapper);
97       delete wrapper;
98     }
99   }
100 
101   // Keep the user_data pointer provided when setting the callbacks.
102   void* user_data_;
103 
104   // Function pointers set by the TfLiteBenchmarkListenerSetCallbacks call.
105   // Only non-null callbacks will be actually called.
106   void (*on_benchmark_start_fn_)(void* user_data);
107   void (*on_single_run_start_fn_)(void* user_data,
108                                   TfLiteBenchmarkRunType runType);
109   void (*on_single_run_end_fn_)(void* user_data);
110   void (*on_benchmark_end_fn_)(void* user_data,
111                                TfLiteBenchmarkResults* results);
112 };
113 
114 struct TfLiteBenchmarkListener {
115   std::unique_ptr<BenchmarkListenerAdapter> adapter;
116 };
117 
TfLiteBenchmarkListenerCreate()118 TfLiteBenchmarkListener* TfLiteBenchmarkListenerCreate() {
119   std::unique_ptr<BenchmarkListenerAdapter> adapter(
120       new BenchmarkListenerAdapter());
121   return new TfLiteBenchmarkListener{std::move(adapter)};
122 }
123 
TfLiteBenchmarkListenerDelete(TfLiteBenchmarkListener * listener)124 void TfLiteBenchmarkListenerDelete(TfLiteBenchmarkListener* listener) {
125   delete listener;
126 }
127 
TfLiteBenchmarkListenerSetCallbacks(TfLiteBenchmarkListener * listener,void * user_data,void (* on_benchmark_start_fn)(void * user_data),void (* on_single_run_start_fn)(void * user_data,TfLiteBenchmarkRunType runType),void (* on_single_run_end_fn)(void * user_data),void (* on_benchmark_end_fn)(void * user_data,TfLiteBenchmarkResults * results))128 void TfLiteBenchmarkListenerSetCallbacks(
129     TfLiteBenchmarkListener* listener, void* user_data,
130     void (*on_benchmark_start_fn)(void* user_data),
131     void (*on_single_run_start_fn)(void* user_data,
132                                    TfLiteBenchmarkRunType runType),
133     void (*on_single_run_end_fn)(void* user_data),
134     void (*on_benchmark_end_fn)(void* user_data,
135                                 TfLiteBenchmarkResults* results)) {
136   listener->adapter->user_data_ = user_data;
137   listener->adapter->on_benchmark_start_fn_ = on_benchmark_start_fn;
138   listener->adapter->on_single_run_start_fn_ = on_single_run_start_fn;
139   listener->adapter->on_single_run_end_fn_ = on_single_run_end_fn;
140   listener->adapter->on_benchmark_end_fn_ = on_benchmark_end_fn;
141 }
142 
143 // -----------------------------------------------------------------------------
144 // C APIs corresponding to tflite::benchmark::BenchmarkTfLiteModel type.
145 // -----------------------------------------------------------------------------
146 struct TfLiteBenchmarkTfLiteModel {
147   std::unique_ptr<tflite::benchmark::BenchmarkTfLiteModel> benchmark_model;
148 };
149 
TfLiteBenchmarkTfLiteModelCreate()150 TfLiteBenchmarkTfLiteModel* TfLiteBenchmarkTfLiteModelCreate() {
151   std::unique_ptr<tflite::benchmark::BenchmarkTfLiteModel> benchmark_model(
152       new tflite::benchmark::BenchmarkTfLiteModel());
153   return new TfLiteBenchmarkTfLiteModel{std::move(benchmark_model)};
154 }
155 
TfLiteBenchmarkTfLiteModelDelete(TfLiteBenchmarkTfLiteModel * benchmark_model)156 void TfLiteBenchmarkTfLiteModelDelete(
157     TfLiteBenchmarkTfLiteModel* benchmark_model) {
158   delete benchmark_model;
159 }
160 
TfLiteBenchmarkTfLiteModelInit(TfLiteBenchmarkTfLiteModel * benchmark_model)161 TfLiteStatus TfLiteBenchmarkTfLiteModelInit(
162     TfLiteBenchmarkTfLiteModel* benchmark_model) {
163   return benchmark_model->benchmark_model->Init();
164 }
165 
TfLiteBenchmarkTfLiteModelRun(TfLiteBenchmarkTfLiteModel * benchmark_model)166 TfLiteStatus TfLiteBenchmarkTfLiteModelRun(
167     TfLiteBenchmarkTfLiteModel* benchmark_model) {
168   return benchmark_model->benchmark_model->Run();
169 }
170 
TfLiteBenchmarkTfLiteModelRunWithArgs(TfLiteBenchmarkTfLiteModel * benchmark_model,int argc,char ** argv)171 TfLiteStatus TfLiteBenchmarkTfLiteModelRunWithArgs(
172     TfLiteBenchmarkTfLiteModel* benchmark_model, int argc, char** argv) {
173   return benchmark_model->benchmark_model->Run(argc, argv);
174 }
175 
TfLiteBenchmarkTfLiteModelAddListener(TfLiteBenchmarkTfLiteModel * benchmark_model,const TfLiteBenchmarkListener * listener)176 void TfLiteBenchmarkTfLiteModelAddListener(
177     TfLiteBenchmarkTfLiteModel* benchmark_model,
178     const TfLiteBenchmarkListener* listener) {
179   return benchmark_model->benchmark_model->AddListener(listener->adapter.get());
180 }
181 
182 }  // extern "C"
183