xref: /aosp_15_r20/external/tflite-support/tensorflow_lite_support/cc/task/core/tflite_engine.cc (revision b16991f985baa50654c05c5adbb3c8bbcfb40082)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
17 
18 #include <unistd.h>
19 
20 #include "absl/strings/match.h"
21 #include "absl/strings/str_cat.h"
22 #include "tensorflow/lite/builtin_ops.h"
23 #include "tensorflow/lite/stderr_reporter.h"
24 #include "tensorflow_lite_support/cc/common.h"
25 #include "tensorflow_lite_support/cc/port/status_macros.h"
26 #include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
27 
28 #if TFLITE_USE_C_API
29 #include "tensorflow/lite/c/c_api_experimental.h"
30 #else
31 #include "tensorflow/lite/kernels/register.h"
32 #endif
33 
34 namespace tflite {
35 namespace task {
36 namespace core {
37 
#ifdef __ANDROID__
// Defines a std::ios_base::Init object in this translation unit.
//
// Per https://github.com/opencv/opencv/issues/14906, Android's <iostream>
// (clang toolchain, NDK 20) does not provide the ios_base::Init object that
// the standard says <iostream> behaves as if it defines. This object stands
// in for it, presumably so the standard streams are initialized before this
// file uses them.
//
// References:
//   https://en.cppreference.com/w/cpp/io/ios_base/Init
//   https://github.com/gcc-mirror/gcc/blob/gcc-8-branch/libstdc%2B%2B-v3/include/std/iostream#L73-L74
namespace {
std::ios_base::Init s_iostream_initializer;
}  // namespace
#endif
52 
53 using ::absl::StatusCode;
54 using ::tflite::support::CreateStatusWithPayload;
55 using ::tflite::support::TfLiteSupportStatus;
56 
Report(const char * format,va_list args)57 int TfLiteEngine::ErrorReporter::Report(const char* format, va_list args) {
58   return std::vsnprintf(error_message, sizeof(error_message), format, args);
59 }
60 
61 #if TFLITE_USE_C_API
TfLiteEngine(std::unique_ptr<tflite::OpResolver> resolver)62 TfLiteEngine::TfLiteEngine(std::unique_ptr<tflite::OpResolver> resolver)
63     : model_(nullptr, TfLiteModelDelete),
64       resolver_(std::move(resolver)) {}
65 #else
TfLiteEngine(std::unique_ptr<tflite::OpResolver> resolver)66 TfLiteEngine::TfLiteEngine(std::unique_ptr<tflite::OpResolver> resolver)
67     : model_(), resolver_(std::move(resolver)) {}
68 #endif
69 
GetInputs()70 std::vector<TfLiteTensor*> TfLiteEngine::GetInputs() {
71   Interpreter* interpreter = this->interpreter();
72   std::vector<TfLiteTensor*> tensors;
73   int input_count = InputCount(interpreter);
74   tensors.reserve(input_count);
75   for (int index = 0; index < input_count; index++) {
76     tensors.push_back(GetInput(interpreter, index));
77   }
78   return tensors;
79 }
80 
GetOutputs()81 std::vector<const TfLiteTensor*> TfLiteEngine::GetOutputs() {
82   Interpreter* interpreter = this->interpreter();
83   std::vector<const TfLiteTensor*> tensors;
84   int output_count = OutputCount(interpreter);
85   tensors.reserve(output_count);
86   for (int index = 0; index < output_count; index++) {
87     tensors.push_back(GetOutput(interpreter, index));
88   }
89   return tensors;
90 }
91 
92 // The following function is adapted from the code in tflite::FlatBufferModel::BuildFromBuffer.
BuildModelFromBuffer(const char * buffer_data,size_t buffer_size)93 void TfLiteEngine::BuildModelFromBuffer(const char* buffer_data, size_t buffer_size) {
94 #if TFLITE_USE_C_API
95   // First verify with the base flatbuffers verifier.
96   // This verifies that the model is a valid flatbuffer model.
97   flatbuffers::Verifier base_verifier(
98       reinterpret_cast<const uint8_t*>(buffer_data), buffer_size);
99   if (!VerifyModelBuffer(base_verifier)) {
100     TF_LITE_REPORT_ERROR(&error_reporter_,
101                          "The model is not a valid Flatbuffer buffer");
102     model_ = nullptr;
103     return;
104   }
105   // Build the model.
106   model_.reset(TfLiteModelCreate(buffer_data, buffer_size));
107 #else
108   // Warning: This branch of the if-statement lacks a verification step for the model.
109   model_ = tflite::FlatBufferModel::BuildFromBuffer(
110       buffer_data, buffer_size, &error_reporter_);
111 #endif
112 }
113 
InitializeFromModelFileHandler()114 absl::Status TfLiteEngine::InitializeFromModelFileHandler() {
115   const char* buffer_data = model_file_handler_->GetFileContent().data();
116   size_t buffer_size = model_file_handler_->GetFileContent().size();
117   BuildModelFromBuffer(buffer_data, buffer_size);
118   if (model_ == nullptr) {
119     // To be replaced with a proper switch-case when TF Lite model builder
120     // returns a `TfLiteStatus` code capturing this type of error.
121     if (absl::StrContains(error_reporter_.error_message,
122                           "The model is not a valid Flatbuffer")) {
123       return CreateStatusWithPayload(
124           StatusCode::kInvalidArgument, error_reporter_.error_message,
125           TfLiteSupportStatus::kInvalidFlatBufferError);
126     } else {
127       // TODO(b/154917059): augment status with another `TfLiteStatus` code when
128       // ready. And use a new `TfLiteStatus::kCoreTfLiteError` for the TFLS
129       // code, instead of the unspecified `kError`.
130       return CreateStatusWithPayload(
131           StatusCode::kUnknown,
132           absl::StrCat(
133               "Could not build model from the provided pre-loaded flatbuffer: ",
134               error_reporter_.error_message));
135     }
136   }
137 
138   ASSIGN_OR_RETURN(
139       model_metadata_extractor_,
140       tflite::metadata::ModelMetadataExtractor::CreateFromModelBuffer(
141           buffer_data, buffer_size));
142 
143   return absl::OkStatus();
144 }
145 
BuildModelFromFlatBuffer(const char * buffer_data,size_t buffer_size)146 absl::Status TfLiteEngine::BuildModelFromFlatBuffer(const char* buffer_data,
147                                                     size_t buffer_size) {
148   if (model_) {
149     return CreateStatusWithPayload(StatusCode::kInternal,
150                                    "Model already built");
151   }
152   external_file_.set_file_content(std::string(buffer_data, buffer_size));
153   ASSIGN_OR_RETURN(
154       model_file_handler_,
155       ExternalFileHandler::CreateFromExternalFile(&external_file_));
156   return InitializeFromModelFileHandler();
157 }
158 
BuildModelFromFile(const std::string & file_name)159 absl::Status TfLiteEngine::BuildModelFromFile(const std::string& file_name) {
160   if (model_) {
161     return CreateStatusWithPayload(StatusCode::kInternal,
162                                    "Model already built");
163   }
164   external_file_.set_file_name(file_name);
165   ASSIGN_OR_RETURN(
166       model_file_handler_,
167       ExternalFileHandler::CreateFromExternalFile(&external_file_));
168   return InitializeFromModelFileHandler();
169 }
170 
BuildModelFromFileDescriptor(int file_descriptor)171 absl::Status TfLiteEngine::BuildModelFromFileDescriptor(int file_descriptor) {
172   if (model_) {
173     return CreateStatusWithPayload(StatusCode::kInternal,
174                                    "Model already built");
175   }
176   external_file_.mutable_file_descriptor_meta()->set_fd(file_descriptor);
177   ASSIGN_OR_RETURN(
178       model_file_handler_,
179       ExternalFileHandler::CreateFromExternalFile(&external_file_));
180   return InitializeFromModelFileHandler();
181 }
182 
BuildModelFromExternalFileProto(const ExternalFile * external_file)183 absl::Status TfLiteEngine::BuildModelFromExternalFileProto(
184     const ExternalFile* external_file) {
185   if (model_) {
186     return CreateStatusWithPayload(StatusCode::kInternal,
187                                    "Model already built");
188   }
189   ASSIGN_OR_RETURN(model_file_handler_,
190                    ExternalFileHandler::CreateFromExternalFile(external_file));
191   return InitializeFromModelFileHandler();
192 }
193 
InitInterpreter(int num_threads)194 absl::Status TfLiteEngine::InitInterpreter(int num_threads) {
195   tflite::proto::ComputeSettings compute_settings;
196   return InitInterpreter(compute_settings, num_threads);
197 }
198 
#if TFLITE_USE_C_API
// Op-resolution callbacks for TfLiteInterpreterOptionsSetOpResolver().
//
// `user_data` is the tflite::OpResolver registered together with these
// callbacks. Both simply forward the lookup to OpResolver::FindOp. They are
// used only as callbacks within this file, so they get internal linkage.
static const TfLiteRegistration* FindBuiltinOp(void* user_data,
                                               TfLiteBuiltinOperator builtin_op,
                                               int version) {
  const auto* op_resolver = static_cast<const OpResolver*>(user_data);
  // The C enum and the flatbuffer enum share the same numeric values.
  const auto op = static_cast<tflite::BuiltinOperator>(builtin_op);
  return op_resolver->FindOp(op, version);
}

static const TfLiteRegistration* FindCustomOp(void* user_data,
                                              const char* custom_op,
                                              int version) {
  const auto* op_resolver = static_cast<const OpResolver*>(user_data);
  return op_resolver->FindOp(custom_op, version);
}
#endif
214 
InitInterpreter(const tflite::proto::ComputeSettings & compute_settings,int num_threads)215 absl::Status TfLiteEngine::InitInterpreter(
216     const tflite::proto::ComputeSettings& compute_settings, int num_threads) {
217   if (model_ == nullptr) {
218     return CreateStatusWithPayload(
219         StatusCode::kInternal,
220         "TF Lite FlatBufferModel is null. Please make sure to call one of the "
221         "BuildModelFrom methods before calling InitInterpreter.");
222   }
223 #if TFLITE_USE_C_API
224   std::function<absl::Status(TfLiteDelegate*,
225                              std::unique_ptr<Interpreter, InterpreterDeleter>*)>
226       initializer = [this, num_threads](
227           TfLiteDelegate* optional_delegate,
228           std::unique_ptr<Interpreter, InterpreterDeleter>* interpreter_out)
229       -> absl::Status {
230     std::unique_ptr<TfLiteInterpreterOptions,
231                     void (*)(TfLiteInterpreterOptions*)>
232         options{TfLiteInterpreterOptionsCreate(),
233                 TfLiteInterpreterOptionsDelete};
234     TfLiteInterpreterOptionsSetOpResolver(options.get(), FindBuiltinOp,
235                                           FindCustomOp, resolver_.get());
236     TfLiteInterpreterOptionsSetNumThreads(options.get(), num_threads);
237     if (optional_delegate != nullptr) {
238       TfLiteInterpreterOptionsAddDelegate(options.get(), optional_delegate);
239     }
240     interpreter_out->reset(
241         TfLiteInterpreterCreateWithSelectedOps(model_.get(), options.get()));
242     if (*interpreter_out == nullptr) {
243       return CreateStatusWithPayload(
244           StatusCode::kAborted,
245           absl::StrCat("Could not build the TF Lite interpreter: "
246                        "TfLiteInterpreterCreateWithSelectedOps failed: ",
247                        error_reporter_.error_message));
248     }
249     return absl::OkStatus();
250   };
251 #else
252   auto initializer =
253       [this, num_threads](
254           std::unique_ptr<Interpreter, InterpreterDeleter>* interpreter_out)
255       -> absl::Status {
256     if (tflite::InterpreterBuilder(*model_, *resolver_)(
257             interpreter_out, num_threads) != kTfLiteOk) {
258       return CreateStatusWithPayload(
259           StatusCode::kUnknown,
260           absl::StrCat("Could not build the TF Lite interpreter: ",
261                        error_reporter_.error_message));
262     }
263     if (*interpreter_out == nullptr) {
264       return CreateStatusWithPayload(StatusCode::kInternal,
265                                      "TF Lite interpreter is null.");
266     }
267     return absl::OkStatus();
268   };
269 #endif
270 
271   absl::Status status =
272       interpreter_.InitializeWithFallback(initializer, compute_settings);
273 
274   if (!status.ok() &&
275       !status.GetPayload(tflite::support::kTfLiteSupportPayload).has_value()) {
276     status = CreateStatusWithPayload(status.code(), status.message());
277   }
278   return status;
279 }
280 
281 }  // namespace core
282 }  // namespace task
283 }  // namespace tflite
284