/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"

#include <unistd.h>

#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/stderr_reporter.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"

#if TFLITE_USE_C_API
#include "tensorflow/lite/c/c_api_experimental.h"
#else
#include "tensorflow/lite/kernels/register.h"
#endif

namespace tflite {
namespace task {
namespace core {

#ifdef __ANDROID__
// https://github.com/opencv/opencv/issues/14906
// The "ios_base::Init" object is not part of Android's "iostream" header (in
// the case of the clang toolchain, NDK 20).
//
// Ref1:
// https://en.cppreference.com/w/cpp/io/ios_base/Init
// The header <iostream> behaves as if it defines (directly or indirectly)
// an instance of std::ios_base::Init with static storage duration.
//
// Ref2:
// https://github.com/gcc-mirror/gcc/blob/gcc-8-branch/libstdc%2B%2B-v3/include/std/iostream#L73-L74
static std::ios_base::Init s_iostream_initializer;
#endif

using ::absl::StatusCode;
using ::tflite::support::CreateStatusWithPayload;
using ::tflite::support::TfLiteSupportStatus;

int TfLiteEngine::ErrorReporter::Report(const char* format, va_list args) {
  return std::vsnprintf(error_message, sizeof(error_message), format, args);
}
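
// Note: this reporter keeps only the most recent message, truncated to the
// fixed size of `error_message`; each call to Report() overwrites the buffer.
// A minimal sketch of how a message reaches it, via the TF_LITE_REPORT_ERROR
// macro used further below in this file (`name` is a hypothetical argument):
//
//   TF_LITE_REPORT_ERROR(&error_reporter_, "Failed to load model: %s", name);
//   // error_reporter_.error_message now holds the formatted text.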

#if TFLITE_USE_C_API
TfLiteEngine::TfLiteEngine(std::unique_ptr<tflite::OpResolver> resolver)
    : model_(nullptr, TfLiteModelDelete),
      resolver_(std::move(resolver)) {}
#else
TfLiteEngine::TfLiteEngine(std::unique_ptr<tflite::OpResolver> resolver)
    : model_(), resolver_(std::move(resolver)) {}
#endif

std::vector<TfLiteTensor*> TfLiteEngine::GetInputs() {
  Interpreter* interpreter = this->interpreter();
  std::vector<TfLiteTensor*> tensors;
  int input_count = InputCount(interpreter);
  tensors.reserve(input_count);
  for (int index = 0; index < input_count; index++) {
    tensors.push_back(GetInput(interpreter, index));
  }
  return tensors;
}

std::vector<const TfLiteTensor*> TfLiteEngine::GetOutputs() {
  Interpreter* interpreter = this->interpreter();
  std::vector<const TfLiteTensor*> tensors;
  int output_count = OutputCount(interpreter);
  tensors.reserve(output_count);
  for (int index = 0; index < output_count; index++) {
    tensors.push_back(GetOutput(interpreter, index));
  }
  return tensors;
}
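
// Example (a minimal sketch, assuming `engine` has already had its model and
// interpreter initialized via the methods below, and that the C API tensor
// accessors are available): iterate over the output tensors to inspect their
// names and byte sizes.
//
//   for (const TfLiteTensor* tensor : engine.GetOutputs()) {
//     printf("%s: %zu bytes\n", TfLiteTensorName(tensor),
//            TfLiteTensorByteSize(tensor));
//   }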

// The following function is adapted from the code in
// tflite::FlatBufferModel::BuildFromBuffer.
void TfLiteEngine::BuildModelFromBuffer(const char* buffer_data,
                                        size_t buffer_size) {
#if TFLITE_USE_C_API
  // First verify with the base flatbuffers verifier that the buffer is a
  // valid FlatBuffer model.
  flatbuffers::Verifier base_verifier(
      reinterpret_cast<const uint8_t*>(buffer_data), buffer_size);
  if (!VerifyModelBuffer(base_verifier)) {
    TF_LITE_REPORT_ERROR(&error_reporter_,
                         "The model is not a valid Flatbuffer buffer");
    model_ = nullptr;
    return;
  }
  // Build the model.
  model_.reset(TfLiteModelCreate(buffer_data, buffer_size));
#else
  // Warning: unlike the C API branch above, this branch performs no explicit
  // verification step on the model buffer.
  model_ = tflite::FlatBufferModel::BuildFromBuffer(buffer_data, buffer_size,
                                                    &error_reporter_);
#endif
}
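
// For callers that want the same up-front check in the non-C-API build, a
// minimal sketch using the flatbuffers verifier directly (the same calls as
// in the C API branch above; `buffer_data` and `buffer_size` stand for the
// caller's own buffer):
//
//   flatbuffers::Verifier verifier(
//       reinterpret_cast<const uint8_t*>(buffer_data), buffer_size);
//   bool is_valid = tflite::VerifyModelBuffer(verifier);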

absl::Status TfLiteEngine::InitializeFromModelFileHandler() {
  const char* buffer_data = model_file_handler_->GetFileContent().data();
  size_t buffer_size = model_file_handler_->GetFileContent().size();
  BuildModelFromBuffer(buffer_data, buffer_size);
  if (model_ == nullptr) {
    // To be replaced with a proper switch-case when the TF Lite model builder
    // returns a `TfLiteStatus` code capturing this type of error.
    if (absl::StrContains(error_reporter_.error_message,
                          "The model is not a valid Flatbuffer")) {
      return CreateStatusWithPayload(
          StatusCode::kInvalidArgument, error_reporter_.error_message,
          TfLiteSupportStatus::kInvalidFlatBufferError);
    } else {
      // TODO(b/154917059): augment status with another `TfLiteStatus` code
      // when ready. And use a new `TfLiteSupportStatus::kCoreTfLiteError` for
      // the TFLS code, instead of the unspecified `kError`.
      return CreateStatusWithPayload(
          StatusCode::kUnknown,
          absl::StrCat(
              "Could not build model from the provided pre-loaded flatbuffer: ",
              error_reporter_.error_message));
    }
  }

  ASSIGN_OR_RETURN(
      model_metadata_extractor_,
      tflite::metadata::ModelMetadataExtractor::CreateFromModelBuffer(
          buffer_data, buffer_size));

  return absl::OkStatus();
}
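
// How the error above surfaces to callers (a minimal sketch; `engine`, `data`
// and `size` are hypothetical): an invalid buffer yields `kInvalidArgument`
// with the `kInvalidFlatBufferError` payload, so callers can branch on the
// canonical status code.
//
//   absl::Status status = engine.BuildModelFromFlatBuffer(data, size);
//   if (absl::IsInvalidArgument(status)) {
//     // The buffer is not a valid TF Lite FlatBuffer.
//   }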

absl::Status TfLiteEngine::BuildModelFromFlatBuffer(const char* buffer_data,
                                                    size_t buffer_size) {
  if (model_) {
    return CreateStatusWithPayload(StatusCode::kInternal,
                                   "Model already built");
  }
  external_file_.set_file_content(std::string(buffer_data, buffer_size));
  ASSIGN_OR_RETURN(
      model_file_handler_,
      ExternalFileHandler::CreateFromExternalFile(&external_file_));
  return InitializeFromModelFileHandler();
}
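
// Example usage (sketch; assumes `contents` holds the bytes of a .tflite
// model read elsewhere, and that `engine` was constructed with a suitable op
// resolver):
//
//   std::string contents = ...;  // model bytes
//   RETURN_IF_ERROR(engine.BuildModelFromFlatBuffer(contents.data(),
//                                                   contents.size()));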

absl::Status TfLiteEngine::BuildModelFromFile(const std::string& file_name) {
  if (model_) {
    return CreateStatusWithPayload(StatusCode::kInternal,
                                   "Model already built");
  }
  external_file_.set_file_name(file_name);
  ASSIGN_OR_RETURN(
      model_file_handler_,
      ExternalFileHandler::CreateFromExternalFile(&external_file_));
  return InitializeFromModelFileHandler();
}
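
// Example usage (a minimal sketch; the path is hypothetical, and the builtin
// op resolver is one possible choice among others):
//
//   TfLiteEngine engine(
//       absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
//   RETURN_IF_ERROR(engine.BuildModelFromFile("/path/to/model.tflite"));
//   RETURN_IF_ERROR(engine.InitInterpreter(/*num_threads=*/1));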
170
BuildModelFromFileDescriptor(int file_descriptor)171 absl::Status TfLiteEngine::BuildModelFromFileDescriptor(int file_descriptor) {
172 if (model_) {
173 return CreateStatusWithPayload(StatusCode::kInternal,
174 "Model already built");
175 }
176 external_file_.mutable_file_descriptor_meta()->set_fd(file_descriptor);
177 ASSIGN_OR_RETURN(
178 model_file_handler_,
179 ExternalFileHandler::CreateFromExternalFile(&external_file_));
180 return InitializeFromModelFileHandler();
181 }
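
// Example usage (sketch; requires <fcntl.h> for open(), and the path is
// hypothetical). The descriptor should stay valid at least until the model
// has been built:
//
//   int fd = open("/path/to/model.tflite", O_RDONLY);
//   RETURN_IF_ERROR(engine.BuildModelFromFileDescriptor(fd));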

absl::Status TfLiteEngine::BuildModelFromExternalFileProto(
    const ExternalFile* external_file) {
  if (model_) {
    return CreateStatusWithPayload(StatusCode::kInternal,
                                   "Model already built");
  }
  ASSIGN_OR_RETURN(model_file_handler_,
                   ExternalFileHandler::CreateFromExternalFile(external_file));
  return InitializeFromModelFileHandler();
}
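
// Example usage (sketch): populate an ExternalFile proto and hand it to the
// engine. Since the engine receives a raw pointer, the proto is expected to
// outlive the call (the path below is hypothetical):
//
//   ExternalFile external_file;
//   external_file.set_file_name("/path/to/model.tflite");
//   RETURN_IF_ERROR(engine.BuildModelFromExternalFileProto(&external_file));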

absl::Status TfLiteEngine::InitInterpreter(int num_threads) {
  tflite::proto::ComputeSettings compute_settings;
  return InitInterpreter(compute_settings, num_threads);
}

#if TFLITE_USE_C_API
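// The two free functions below adapt the C++ `tflite::OpResolver` interface
// to the C-style callbacks expected by `TfLiteInterpreterOptionsSetOpResolver`:
// the resolver instance is threaded through as the opaque `user_data` pointer.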
const TfLiteRegistration* FindBuiltinOp(void* user_data,
                                        TfLiteBuiltinOperator builtin_op,
                                        int version) {
  OpResolver* op_resolver = reinterpret_cast<OpResolver*>(user_data);
  tflite::BuiltinOperator op = static_cast<tflite::BuiltinOperator>(builtin_op);
  return op_resolver->FindOp(op, version);
}

const TfLiteRegistration* FindCustomOp(void* user_data, const char* custom_op,
                                       int version) {
  OpResolver* op_resolver = reinterpret_cast<OpResolver*>(user_data);
  return op_resolver->FindOp(custom_op, version);
}
#endif

absl::Status TfLiteEngine::InitInterpreter(
    const tflite::proto::ComputeSettings& compute_settings, int num_threads) {
  if (model_ == nullptr) {
    return CreateStatusWithPayload(
        StatusCode::kInternal,
        "TF Lite FlatBufferModel is null. Please make sure to call one of the "
        "BuildModelFrom methods before calling InitInterpreter.");
  }
#if TFLITE_USE_C_API
  std::function<absl::Status(TfLiteDelegate*,
                             std::unique_ptr<Interpreter, InterpreterDeleter>*)>
      initializer = [this, num_threads](
                        TfLiteDelegate* optional_delegate,
                        std::unique_ptr<Interpreter, InterpreterDeleter>*
                            interpreter_out) -> absl::Status {
    std::unique_ptr<TfLiteInterpreterOptions,
                    void (*)(TfLiteInterpreterOptions*)>
        options{TfLiteInterpreterOptionsCreate(),
                TfLiteInterpreterOptionsDelete};
    TfLiteInterpreterOptionsSetOpResolver(options.get(), FindBuiltinOp,
                                          FindCustomOp, resolver_.get());
    TfLiteInterpreterOptionsSetNumThreads(options.get(), num_threads);
    if (optional_delegate != nullptr) {
      TfLiteInterpreterOptionsAddDelegate(options.get(), optional_delegate);
    }
    interpreter_out->reset(
        TfLiteInterpreterCreateWithSelectedOps(model_.get(), options.get()));
    if (*interpreter_out == nullptr) {
      return CreateStatusWithPayload(
          StatusCode::kAborted,
          absl::StrCat("Could not build the TF Lite interpreter: "
                       "TfLiteInterpreterCreateWithSelectedOps failed: ",
                       error_reporter_.error_message));
    }
    return absl::OkStatus();
  };
#else
  auto initializer =
      [this, num_threads](
          std::unique_ptr<Interpreter, InterpreterDeleter>* interpreter_out)
      -> absl::Status {
    if (tflite::InterpreterBuilder(*model_, *resolver_)(
            interpreter_out, num_threads) != kTfLiteOk) {
      return CreateStatusWithPayload(
          StatusCode::kUnknown,
          absl::StrCat("Could not build the TF Lite interpreter: ",
                       error_reporter_.error_message));
    }
    if (*interpreter_out == nullptr) {
      return CreateStatusWithPayload(StatusCode::kInternal,
                                     "TF Lite interpreter is null.");
    }
    return absl::OkStatus();
  };
#endif

  absl::Status status =
      interpreter_.InitializeWithFallback(initializer, compute_settings);

  if (!status.ok() &&
      !status.GetPayload(tflite::support::kTfLiteSupportPayload).has_value()) {
    status = CreateStatusWithPayload(status.code(), status.message());
  }
  return status;
}
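
// End-to-end usage (a minimal sketch; the path, thread count, and resolver
// choice are illustrative, not prescriptive):
//
//   TfLiteEngine engine(
//       absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());
//   RETURN_IF_ERROR(engine.BuildModelFromFile("/path/to/model.tflite"));
//   tflite::proto::ComputeSettings settings;  // e.g. to request a delegate
//   RETURN_IF_ERROR(engine.InitInterpreter(settings, /*num_threads=*/2));
//   TfLiteEngine::Interpreter* interpreter = engine.interpreter();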

}  // namespace core
}  // namespace task
}  // namespace tflite