/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/interpreter_builder.h"

#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>

#include <algorithm>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "flatbuffers/flatbuffers.h"  // from @flatbuffers
#include "tensorflow/lite/c/c_api_types.h"
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/flatbuffer_conversions.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/core/macros.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/internal/signature_def.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/internal/compatibility.h"
#include "tensorflow/lite/model_builder.h"
#include "tensorflow/lite/profiling/platform_profiler.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/schema/schema_utils.h"
#include "tensorflow/lite/shared_library.h"
#include "tensorflow/lite/stderr_reporter.h"
#include "tensorflow/lite/string_type.h"
#include "tensorflow/lite/util.h"
#include "tensorflow/lite/version.h"

// aligned_alloc is available (via cstdlib/stdlib.h) with C++17/C11.
#if __cplusplus >= 201703L || __STDC_VERSION__ >= 201112L
#if !defined(__ANDROID__) || __ANDROID_API__ >= 28
// Neither Apple nor Windows provides aligned_alloc.
#if !defined(__APPLE__) && !defined(_WIN32)
#define TFLITE_USE_STD_ALIGNED_ALLOC
#endif
#endif
#endif

// TODO(b/139446230): Move to portable platform header.
#if defined(__ANDROID__)
#define TFLITE_IS_MOBILE_PLATFORM
#endif  // defined(__ANDROID__)

#if defined(__APPLE__)
#include "TargetConditionals.h"
#if TARGET_IPHONE_SIMULATOR
#define TFLITE_IS_MOBILE_PLATFORM
#elif TARGET_OS_IPHONE
#define TFLITE_IS_MOBILE_PLATFORM
#endif
#endif  // defined(__APPLE__)

namespace tflite {

namespace {

// Ensure that ErrorReporter is non-null.
ErrorReporter* ValidateErrorReporter(ErrorReporter* e) {
  return e ? e : DefaultErrorReporter();
}

template <typename T>
TfLiteStatus Copy(const T* data_ptr, TfLiteIntArray** arr) {
  if (data_ptr->values() == nullptr) {
    return kTfLiteError;
  }

  int size = data_ptr->values()->size();
  *arr = TfLiteIntArrayCreate(size);
  for (int i = 0; i < size; i++) {
    (*arr)->data[i] = static_cast<int>(data_ptr->values()->Get(i));
  }
  return kTfLiteOk;
}

TfLiteStatus ParseSparseIndexVector(const DimensionMetadata* src,
                                    TfLiteDimensionMetadata* tgt) {
  if (src->array_segments() == nullptr || src->array_indices() == nullptr) {
    return kTfLiteError;
  }
  TfLiteStatus status = kTfLiteOk;
  switch (src->array_segments_type()) {
    case SparseIndexVector_Int32Vector:
      status = Copy(src->array_segments_as_Int32Vector(), &tgt->array_segments);
      break;
    case SparseIndexVector_Uint16Vector:
      status =
          Copy(src->array_segments_as_Uint16Vector(), &tgt->array_segments);
      break;
    case SparseIndexVector_Uint8Vector:
      status = Copy(src->array_segments_as_Uint8Vector(), &tgt->array_segments);
      break;
    default:
      status = kTfLiteError;
      break;
  }
  if (status != kTfLiteOk) return status;

  switch (src->array_indices_type()) {
    case SparseIndexVector_Int32Vector:
      return Copy(src->array_indices_as_Int32Vector(), &tgt->array_indices);
    case SparseIndexVector_Uint16Vector:
      return Copy(src->array_indices_as_Uint16Vector(), &tgt->array_indices);
    case SparseIndexVector_Uint8Vector:
      return Copy(src->array_indices_as_Uint8Vector(), &tgt->array_indices);
    default:
      break;
  }
  return kTfLiteError;
}

// Helper that builds a std::map from a flatbuffer vector of TensorMap.
std::map<std::string, uint32_t> GetMapFromTensorMap(
    const flatbuffers::Vector<flatbuffers::Offset<tflite::TensorMap>>*
        tensor_map) {
  if (!tensor_map) return {};
  std::map<std::string, uint32_t> result;
  for (const auto tensor : *tensor_map) {
    if (tensor != nullptr && tensor->name() != nullptr) {
      result[tensor->name()->c_str()] = tensor->tensor_index();
    }
  }
  return result;
}

inline bool ShouldCreateLazyDelegateProviders(int num_fp32_tensors) {
#if defined(XNNPACK_DELEGATE_ENABLE_QS8) || defined(XNNPACK_DELEGATE_ENABLE_QU8)
  return true;
#else
  return num_fp32_tensors > 0;
#endif
}

}  // namespace

constexpr const char* kEmptyTensorName = "";

// Using weak symbols to create a delegate allows automatic injection of the
// delegate simply by adding it as a dependency.
// For the flex delegate, see also the strong override in
// lite/delegates/flex/delegate.cc.
TFLITE_ATTRIBUTE_WEAK Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() {
  // TF_AcquireFlexDelegate isn't defined on Android, and the following block
  // of code would have no effect if TF_AcquireFlexDelegate isn't defined, so
  // we only enable that block for non-Android platforms. Also, on Android 4.4
  // (KitKat), the dlsym() implementation has a bug where dlsym() of an unknown
  // name will result in a SIGFPE, which would crash the process, so it's
  // important that on Android 4.4 we *don't* call SharedLibrary::GetSymbol
  // unless the symbol is sure to exist.
#if !defined(__ANDROID__)
  auto acquire_flex_delegate_func =
      reinterpret_cast<Interpreter::TfLiteDelegatePtr (*)()>(
          SharedLibrary::GetSymbol("TF_AcquireFlexDelegate"));
  if (acquire_flex_delegate_func) {
    return acquire_flex_delegate_func();
  }
#endif

#if !defined(TFLITE_IS_MOBILE_PLATFORM)
  // Load TF_AcquireFlexDelegate() from _pywrap_tensorflow_internal.so if it is
  // available.
#if defined(_WIN32)
  const wchar_t* filename_pywrap_tensorflow_internal =
      L"_pywrap_tensorflow_internal.pyd";
#elif defined(__APPLE__)
  const char* filename_pywrap_tensorflow_internal =
      "python/_pywrap_tensorflow_internal.so";
#else
  const char* filename_pywrap_tensorflow_internal =
      "_pywrap_tensorflow_internal.so";
#endif
  void* lib_tf_internal =
      SharedLibrary::LoadLibrary(filename_pywrap_tensorflow_internal);
#if defined(_WIN32)
  if (lib_tf_internal == nullptr) {
    lib_tf_internal = SharedLibrary::LoadLibrary(
        L"_pywrap_tensorflow_interpreter_wrapper.pyd");
  }
#endif
  if (lib_tf_internal) {
    acquire_flex_delegate_func =
        reinterpret_cast<Interpreter::TfLiteDelegatePtr (*)()>(
            SharedLibrary::GetLibrarySymbol(lib_tf_internal,
                                            "TF_AcquireFlexDelegate"));
    if (acquire_flex_delegate_func) {
      return acquire_flex_delegate_func();
    }
  }
#endif  // !defined(TFLITE_IS_MOBILE_PLATFORM)

  return Interpreter::TfLiteDelegatePtr(nullptr, [](TfLiteDelegate*) {});
}
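
// A strong definition elsewhere replaces this weak fallback at link time. A
// minimal sketch of such an override (the factory and deleter names below are
// hypothetical; the real flex override lives in
// lite/delegates/flex/delegate.cc):
//
//   Interpreter::TfLiteDelegatePtr AcquireFlexDelegate() {
//     return Interpreter::TfLiteDelegatePtr(
//         CreateMyFlexDelegate(),  // hypothetical factory
//         [](TfLiteDelegate* d) { DeleteMyFlexDelegate(d); });  // hypothetical
//   }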

InterpreterBuilder::InterpreterBuilder(
    const FlatBufferModel& model, const OpResolver& op_resolver,
    const InterpreterOptions* options_experimental)
    : model_(model.GetModel()),
      op_resolver_(op_resolver),
      error_reporter_(ValidateErrorReporter(model.error_reporter())),
      metadata_(model.ReadAllMetadata()),
      allocation_(model.allocation()) {
  if (options_experimental) {
    options_ = *options_experimental;
  }
}

InterpreterBuilder::InterpreterBuilder(
    const ::tflite::Model* model, const OpResolver& op_resolver,
    ErrorReporter* error_reporter,
    const InterpreterOptions* options_experimental)
    : model_(model),
      op_resolver_(op_resolver),
      error_reporter_(ValidateErrorReporter(error_reporter)) {
  if (options_experimental) {
    options_ = *options_experimental;
  }
}
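
// Typical client usage, as a sketch (assumes a valid model file and the
// builtin op resolver from tensorflow/lite/kernels/register.h):
//
//   auto model = FlatBufferModel::BuildFromFile("model.tflite");
//   ops::builtin::BuiltinOpResolver resolver;
//   std::unique_ptr<Interpreter> interpreter;
//   InterpreterBuilder builder(*model, resolver);
//   if (builder(&interpreter) == kTfLiteOk) {
//     interpreter->AllocateTensors();
//   }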

InterpreterBuilder::~InterpreterBuilder() {}

TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
  TfLiteStatus status = kTfLiteOk;
  // Reset state.
  flatbuffer_op_index_to_registration_.clear();
  unresolved_custom_ops_.clear();

  auto opcodes = model_->operator_codes();
  if (!opcodes) {
    return status;
  }
  int num_custom_ops = 0;
  for (const OperatorCode* opcode : *opcodes) {
    if (GetBuiltinCode(opcode) == BuiltinOperator_CUSTOM) {
      num_custom_ops++;
    }
  }
  unresolved_custom_ops_.reserve(num_custom_ops);
  for (const OperatorCode* opcode : *opcodes) {
    const TfLiteRegistration* registration = nullptr;
    status = GetRegistrationFromOpCode(opcode, op_resolver_, error_reporter_,
                                       &registration);
    if (status != kTfLiteOk) {
      if (GetBuiltinCode(opcode) != BuiltinOperator_CUSTOM) {
        return status;
      }
      // If it's an unresolved custom op, allow it for now. It might be
      // resolved by a delegate later.
      if (!opcode->custom_code()) {
        error_reporter_->Report(
            "Operator with CUSTOM builtin_code has no custom_code.\n");
        return status;
      }
      const auto* op_name = opcode->custom_code()->c_str();
      unresolved_custom_ops_.push_back(CreateUnresolvedCustomOp(op_name));
      registration = &unresolved_custom_ops_.back();
      has_flex_op_ |= IsFlexOp(op_name);
      status = kTfLiteOk;
    }
    flatbuffer_op_index_to_registration_.push_back(registration);
  }
  return status;
}

namespace {
template <class T>
std::vector<int> FlatBufferIntArrayToVector(T* flat_array) {
  // A null flat_array yields an empty shape: flatbuffers::Pack converts
  // empty vectors to nullptr when constructing models.
  if (flat_array == nullptr) {
    return {};
  }
  std::vector<int> ret(flat_array->size());
  for (int i = 0; i < flat_array->size(); i++) {
    ret[i] = flat_array->Get(i);
  }
  return ret;
}

// Used to determine how the op data parsing function creates its working
// space.
class MallocDataAllocator : public BuiltinDataAllocator {
 public:
  void* Allocate(size_t size, size_t alignment_hint) override {
#ifdef TFLITE_USE_STD_ALIGNED_ALLOC
    // Ensure that alignment is a power of two and a multiple of sizeof(void*),
    // and that size is an integral multiple of alignment.
    size_t used_alignment = std::max(alignment_hint, sizeof(void*));
    size_t used_size =
        ((size + used_alignment - 1) / used_alignment) * used_alignment;
    TFLITE_DCHECK(
        (used_alignment != 0) &&
        ((used_alignment & (used_alignment - 1)) == 0));  // is power-of-two
    return aligned_alloc(used_alignment, used_size);
#else
    return malloc(size);
#endif
  }
  void Deallocate(void* data) override { free(data); }
};
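
// Worked example of the rounding above: with alignment_hint = 16 and
// size = 20, used_alignment = max(16, sizeof(void*)) = 16 and
// used_size = ((20 + 15) / 16) * 16 = 32, which satisfies aligned_alloc's
// requirement that size be an integral multiple of alignment.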

}  // namespace

TfLiteStatus InterpreterBuilder::ParseNodes(
    const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
    Subgraph* subgraph) {
  TfLiteStatus status = kTfLiteOk;

  // Reduce the number of redundant allocations.
  subgraph->ReserveNodes(operators->size());

  for (int i = 0; i < operators->size(); ++i) {
    const auto* op = operators->Get(i);
    int index = op->opcode_index();
    if (index < 0 || index >= flatbuffer_op_index_to_registration_.size()) {
      error_reporter_->Report("Missing registration for opcode_index %d\n",
                              index);
      status = kTfLiteError;
      continue;
    }

    const TfLiteRegistration* registration =
        flatbuffer_op_index_to_registration_[index];
    if (registration == nullptr) {
      error_reporter_->Report("Skipping op for opcode_index %d\n", index);
      status = kTfLiteError;
      continue;
    }

    BuiltinOperator op_type =
        static_cast<BuiltinOperator>(registration->builtin_code);

    if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) {
      error_reporter_->Report(
          "Found builtin operator %s with custom options.\n",
          EnumNameBuiltinOperator(op_type));
    }

    if (op_type == BuiltinOperator_CUSTOM) {
      if (op->custom_options()) {
        subgraph->AddNodeWithParameters(
            FlatBufferIntArrayToVector(op->inputs()),
            FlatBufferIntArrayToVector(op->outputs()),
            FlatBufferIntArrayToVector(op->intermediates()),
            reinterpret_cast<const char*>(op->custom_options()->data()),
            op->custom_options()->size(), nullptr, registration);
      } else {
        subgraph->AddNodeWithParameters(
            FlatBufferIntArrayToVector(op->inputs()),
            FlatBufferIntArrayToVector(op->outputs()),
            FlatBufferIntArrayToVector(op->intermediates()), nullptr, 0,
            nullptr, registration);
      }
    } else {
      void* builtin_data = nullptr;
      MallocDataAllocator malloc_allocator;
      TF_LITE_ENSURE_STATUS(ParseOpData(op, op_type, error_reporter_,
                                        &malloc_allocator, &builtin_data));
      subgraph->AddNodeWithParameters(
          FlatBufferIntArrayToVector(op->inputs()),
          FlatBufferIntArrayToVector(op->outputs()),
          FlatBufferIntArrayToVector(op->intermediates()), nullptr, 0,
          builtin_data, registration);
    }
  }

  return status;
}

TfLiteStatus InterpreterBuilder::ParseQuantization(
    const QuantizationParameters* src_quantization,
    TfLiteQuantization* quantization, const std::vector<int>& dims) {
  quantization->type = kTfLiteNoQuantization;
  if (!src_quantization || !src_quantization->scale() ||
      src_quantization->scale()->size() == 0) {
    return kTfLiteOk;
  }
  if (!src_quantization->zero_point()) {
    error_reporter_->Report(
        "Quantization parameters have non-null scale but null zero_point.");
    return kTfLiteError;
  }

  // Ensure that the number of scales matches the number of zero_points.
  if (src_quantization->scale()->size() !=
      src_quantization->zero_point()->size()) {
    error_reporter_->Report(
        "QuantizationParam has %d zero_point values and %d scale values. Must "
        "have the same number.",
        src_quantization->zero_point()->size(),
        src_quantization->scale()->size());
    return kTfLiteError;
  }

  const size_t num_scales = src_quantization->scale()->size();

  // Ensure that the quantization dimension is valid.
  if (src_quantization->quantized_dimension() < 0 ||
      (!dims.empty() &&
       src_quantization->quantized_dimension() >= dims.size())) {
    error_reporter_->Report(
        "quantized_dimension must be in range [0, %d). Was %d.", dims.size(),
        src_quantization->quantized_dimension());
    return kTfLiteError;
  }

  // Ensure that the number of scales is 1 for per-layer quantization, and
  // matches the number of quantization dimensions for per-axis quantization.
  if (num_scales != 1 &&
      (!dims.empty() &&
       num_scales != dims[src_quantization->quantized_dimension()])) {
    error_reporter_->Report(
        "num_scales must be 1 for per-layer quantization, or %d for per-axis "
        "quantization, but got %d.",
        dims[src_quantization->quantized_dimension()], num_scales);
    return kTfLiteError;
  }

  // Affine quantization.
  quantization->type = kTfLiteAffineQuantization;
  auto* affine_quantization = reinterpret_cast<TfLiteAffineQuantization*>(
      malloc(sizeof(TfLiteAffineQuantization)));
  affine_quantization->scale = TfLiteFloatArrayCreate(num_scales);
  affine_quantization->zero_point = TfLiteIntArrayCreate(num_scales);
  for (size_t i = 0; i < num_scales; ++i) {
    affine_quantization->scale->data[i] = src_quantization->scale()->Get(i);
    affine_quantization->zero_point->data[i] =
        src_quantization->zero_point()->Get(i);
  }
  affine_quantization->quantized_dimension =
      src_quantization->quantized_dimension();
  quantization->params = reinterpret_cast<void*>(affine_quantization);
  return kTfLiteOk;
}
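
// Worked example of the checks above (values are illustrative): a weight
// tensor with dims = {64, 3, 3, 32} and quantized_dimension = 0 is per-axis
// quantized and must carry num_scales == dims[0] == 64 scale/zero_point
// pairs, while per-layer quantization carries exactly one pair regardless of
// the tensor's shape.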

TfLiteStatus InterpreterBuilder::ParseSparsity(
    const SparsityParameters* src_sparsity, TfLiteSparsity** sparsity_ptr) {
  if (!src_sparsity) {
    return kTfLiteOk;
  }

  if (src_sparsity->traversal_order() == nullptr ||
      src_sparsity->dim_metadata() == nullptr) {
    error_reporter_->Report("Invalid sparsity parameter.");
    return kTfLiteError;
  }

  auto* sparsity =
      reinterpret_cast<TfLiteSparsity*>(malloc(sizeof(TfLiteSparsity)));
  memset(sparsity, 0, sizeof(TfLiteSparsity));
  *sparsity_ptr = sparsity;

  const size_t traversal_order_size = src_sparsity->traversal_order()->size();
  sparsity->traversal_order = TfLiteIntArrayCreate(traversal_order_size);
  for (int i = 0; i < traversal_order_size; i++) {
    sparsity->traversal_order->data[i] =
        src_sparsity->traversal_order()->Get(i);
  }

  if (src_sparsity->block_map()) {
    const size_t block_map_size = src_sparsity->block_map()->size();
    sparsity->block_map = TfLiteIntArrayCreate(block_map_size);
    for (int i = 0; i < block_map_size; i++) {
      sparsity->block_map->data[i] = src_sparsity->block_map()->Get(i);
    }
  }

  const size_t dim_metadata_size = src_sparsity->dim_metadata()->size();
  sparsity->dim_metadata_size = dim_metadata_size;
  sparsity->dim_metadata = reinterpret_cast<TfLiteDimensionMetadata*>(
      malloc(dim_metadata_size * sizeof(TfLiteDimensionMetadata)));
  memset(sparsity->dim_metadata, 0,
         dim_metadata_size * sizeof(TfLiteDimensionMetadata));

  for (int i = 0; i < dim_metadata_size; i++) {
    const auto* src_metadata = src_sparsity->dim_metadata()->Get(i);
    if (src_metadata->format() != DimensionType_DENSE &&
        src_metadata->format() != DimensionType_SPARSE_CSR) {
      TF_LITE_REPORT_ERROR(error_reporter_,
                           "The %dth dimension has unknown type: %d.", i,
                           src_metadata->format());
      return kTfLiteError;
    }
    auto* tgt_metadata = &sparsity->dim_metadata[i];

    tgt_metadata->format =
        static_cast<TfLiteDimensionType>(src_metadata->format());

    if (tgt_metadata->format == kTfLiteDimDense) {
      tgt_metadata->dense_size = src_metadata->dense_size();
    } else {
      if (ParseSparseIndexVector(src_metadata, tgt_metadata) != kTfLiteOk) {
        TF_LITE_REPORT_ERROR(
            error_reporter_,
            "The %dth sparse dimension has invalid parameters.", i);
        return kTfLiteError;
      }
    }
  }

  return kTfLiteOk;
}

TfLiteStatus InterpreterBuilder::ParseSignatureDefs(
    const flatbuffers::Vector<flatbuffers::Offset<SignatureDef>>*
        signature_def_list,
    Interpreter* interpreter) {
  if (signature_def_list == nullptr || signature_def_list->size() == 0) {
    return kTfLiteOk;
  }
  std::vector<internal::SignatureDef> signature_defs;
  signature_defs.reserve(signature_def_list->size());
  for (const auto fb_signature_def : *signature_def_list) {
    if (fb_signature_def == nullptr) {
      TF_LITE_REPORT_ERROR(error_reporter_, "NULL SignatureDef in the model.");
      return kTfLiteError;
    }
    if (fb_signature_def->signature_key() == nullptr) {
      TF_LITE_REPORT_ERROR(error_reporter_,
                           "Missing exported method name for SignatureDef");
      return kTfLiteError;
    }
    if (fb_signature_def->inputs() == nullptr) {
      TF_LITE_REPORT_ERROR(error_reporter_,
                           "NULL SignatureDef inputs for exported method %s",
                           fb_signature_def->signature_key()->c_str());
      return kTfLiteError;
    }
    if (fb_signature_def->outputs() == nullptr) {
      TF_LITE_REPORT_ERROR(error_reporter_,
                           "NULL SignatureDef outputs for exported method %s",
                           fb_signature_def->signature_key()->c_str());
      return kTfLiteError;
    }
    signature_defs.resize(signature_defs.size() + 1);
    auto& signature_def = signature_defs.back();
    signature_def.inputs = GetMapFromTensorMap(fb_signature_def->inputs());
    signature_def.outputs = GetMapFromTensorMap(fb_signature_def->outputs());
    signature_def.signature_key = fb_signature_def->signature_key()->c_str();
    signature_def.subgraph_index = fb_signature_def->subgraph_index();
  }
  interpreter->SetSignatureDef(std::move(signature_defs));
  return kTfLiteOk;
}
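
// Once signatures are registered, callers can look them up by key on the
// built interpreter. A sketch ("serving_default" is an assumed key):
//
//   SignatureRunner* runner =
//       interpreter->GetSignatureRunner("serving_default");
//   if (runner != nullptr) runner->AllocateTensors();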

TfLiteStatus InterpreterBuilder::ParseTensors(
    const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
    const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
    Subgraph* subgraph) {
  TfLiteStatus status = kTfLiteOk;

  // A little helper to get the names of inputs and outputs. Note that they
  // must outlive the subgraph.
  auto get_name = [](const tflite::Tensor* t) -> const char* {
    auto name = t->name();
    if (name) return name->c_str();
    return kEmptyTensorName;
  };

  num_fp32_tensors_ = 0;
  for (int i = 0; i < tensors->size(); ++i) {
    const auto* tensor = tensors->Get(i);
    std::vector<int> dims = FlatBufferIntArrayToVector(tensor->shape());

    TfLiteType type;
    if (ConvertTensorType(tensor->type(), &type, error_reporter_) !=
        kTfLiteOk) {
      status = kTfLiteError;
      continue;
    }
    if (type == kTfLiteFloat32) {
      ++num_fp32_tensors_;
    }
    auto get_readonly_data = [&](const char** buffer_data,
                                 size_t* buffer_size) {
      // TODO(aselle): Check what happens if we have an unspecified size
      // constant.
      *buffer_data = nullptr;
      if (tensor->buffer() == 0) return kTfLiteOk;
      if (tensor->buffer() >= buffers->size()) {
        error_reporter_->Report(
            "Tensor %d specifies out of range buffer %d (only %d buffers).\n",
            i, tensor->buffer(), buffers->size());
        return kTfLiteError;
      }
      if (auto* buffer = (*buffers)[tensor->buffer()]) {
        if (auto* array = buffer->data()) {
          *buffer_size = array->size();
          *buffer_data = reinterpret_cast<const char*>(array->data());
          return kTfLiteOk;
        }
      }
      return kTfLiteOk;
    };
    size_t buffer_size = 0;
    const char* buffer_ptr;
    TF_LITE_ENSURE_STATUS(get_readonly_data(&buffer_ptr, &buffer_size));

    const auto* src_quantization = tensor->quantization();
    TfLiteQuantization quantization;
    if (ParseQuantization(src_quantization, &quantization, dims) != kTfLiteOk) {
      error_reporter_->Report("Tensor %d has invalid quantization parameters.",
                              i);
      status = kTfLiteError;
    }

    std::vector<int> dims_signature = {};
    if (tensor->shape_signature()) {
      dims_signature = FlatBufferIntArrayToVector(tensor->shape_signature());
    }

    bool is_variable = tensor->is_variable();
    if (buffer_ptr) {
      if (is_variable) {
        error_reporter_->Report(
            "Tensor %d is a variable tensor with a buffer. "
            "This is not yet supported.\n",
            i);
        status = kTfLiteError;
      }

      // TODO(b/144999664): Only constant sparse tensors are supported now.
      const auto* src_sparsity = tensor->sparsity();
      TfLiteSparsity* sparsity = nullptr;
      if (ParseSparsity(src_sparsity, &sparsity) != kTfLiteOk) {
        error_reporter_->Report("Tensor %d has invalid sparsity parameters.",
                                i);
        status = kTfLiteError;
      }

      if (subgraph->SetTensorParametersReadOnly(
              i, type, get_name(tensor), dims, quantization, buffer_ptr,
              buffer_size, allocation_, sparsity) != kTfLiteOk) {
        error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
                                i);
        status = kTfLiteError;
      }
    } else {
      if (subgraph->SetTensorParametersReadWrite(
              i, type, get_name(tensor), dims, quantization, is_variable,
              dims_signature) != kTfLiteOk) {
        error_reporter_->Report("Tensor %d is invalidly specified in schema.\n",
                                i);
        status = kTfLiteError;
      }
    }
  }

  return status;
}

TfLiteStatus InterpreterBuilder::ApplyDelegates(Interpreter* interpreter) {
  // Apply the Flex delegate if applicable.
  if (has_flex_op_) {
    if (Interpreter::TfLiteDelegatePtr flex_delegate = AcquireFlexDelegate()) {
      TF_LITE_ENSURE_STATUS(interpreter->ModifyGraphWithDelegateImpl(
          // Transfers ownership of flex_delegate to the interpreter.
          std::move(flex_delegate)));
    }
  }
  for (TfLiteDelegate* delegate : delegates_) {
    // Note that we DON'T transfer ownership of the delegate to the
    // interpreter. (Doing that would cause problems if operator() was invoked
    // twice.)
    TF_LITE_ENSURE_STATUS(interpreter->ModifyGraphWithDelegateImpl(delegate));
  }
  return kTfLiteOk;
}

TfLiteStatus InterpreterBuilder::SetNumThreads(int num_threads) {
  if (num_threads < -1) {
    error_reporter_->Report(
        "num_threads should be >= 0, or -1 to let the TFLite runtime set the "
        "value.");
    return kTfLiteError;
  }
  num_threads_ = num_threads;
  return kTfLiteOk;
}
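
// Example: SetNumThreads(-1) defers the thread count to the runtime, while
// SetNumThreads(4) requests a fixed pool of four threads. A sketch; the
// stored value takes effect when operator() builds the interpreter:
//
//   InterpreterBuilder builder(*model, resolver);
//   builder.SetNumThreads(4);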

TfLiteStatus InterpreterBuilder::operator()(
    std::unique_ptr<Interpreter>* interpreter, int num_threads) {
  TfLiteStatus status = SetNumThreads(num_threads);
  if (status != kTfLiteOk) {
    interpreter->reset();
    return status;
  }
  return (*this)(interpreter);
}

TfLiteStatus InterpreterBuilder::operator()(
    std::unique_ptr<Interpreter>* interpreter) {
  if (!interpreter) {
    error_reporter_->Report(
        "Null output pointer passed to InterpreterBuilder.");
    return kTfLiteError;
  }

  // Safe exit by deleting a partially created interpreter, to reduce verbosity
  // on error conditions. Used via `return cleanup_and_error();`.
  auto cleanup_and_error = [&interpreter]() {
    interpreter->reset();
    return kTfLiteError;
  };

  if (!model_) {
    error_reporter_->Report("Null pointer passed in as model.");
    return cleanup_and_error();
  }

  if (model_->version() != TFLITE_SCHEMA_VERSION) {
    error_reporter_->Report(
        "Model provided is schema version %d not equal "
        "to supported version %d.\n",
        model_->version(), TFLITE_SCHEMA_VERSION);
    return cleanup_and_error();
  }

  if (BuildLocalIndexToRegistrationMapping() != kTfLiteOk) {
    error_reporter_->Report("Registration failed.\n");
    return cleanup_and_error();
  }

  // Flatbuffer model schemas define a list of opcodes independent of the
  // graph. We first map those to registrations. This reduces string lookups
  // for custom ops since we only do it once per custom op rather than once
  // per custom op invocation in the model graph.
  // Construct the interpreter with the correct number of tensors and
  // operators.
  auto* subgraphs = model_->subgraphs();
  auto* buffers = model_->buffers();

  if (!subgraphs || subgraphs->size() == 0) {
    TF_LITE_REPORT_ERROR(error_reporter_, "No subgraph in the model.\n");
    return cleanup_and_error();
  }

  if (!buffers) {
    TF_LITE_REPORT_ERROR(error_reporter_, "No buffers in the model.\n");
    return cleanup_and_error();
  }

  *interpreter = std::make_unique<Interpreter>(error_reporter_);
  if (subgraphs->size() > 1) {
    (*interpreter)->AddSubgraphs(subgraphs->size() - 1);
  }

  // Set the number of threads after all the subgraphs are added.
  (*interpreter)->SetNumThreads(num_threads_);

  // Set the interpreter options.
  (*interpreter)->ApplyOptionsImpl(&options_);

  (*interpreter)
      ->SetProfilerImpl(tflite::profiling::MaybeCreatePlatformProfiler());

  for (int subgraph_index = 0; subgraph_index < subgraphs->size();
       ++subgraph_index) {
    const tflite::SubGraph* subgraph = (*subgraphs)[subgraph_index];
    tflite::Subgraph* modified_subgraph =
        (*interpreter)->subgraph(subgraph_index);
    auto operators = subgraph->operators();
    auto tensors = subgraph->tensors();
    if (!tensors) {
      TF_LITE_REPORT_ERROR(error_reporter_,
                           "Did not get tensors in subgraph %d.\n",
                           subgraph_index);
      return cleanup_and_error();
    }
    if (modified_subgraph->AddTensors(tensors->size()) != kTfLiteOk) {
      return cleanup_and_error();
    }
    // Parse inputs/outputs.
    modified_subgraph->SetInputs(
        FlatBufferIntArrayToVector(subgraph->inputs()));
    modified_subgraph->SetOutputs(
        FlatBufferIntArrayToVector(subgraph->outputs()));

    // Finally, set up nodes and tensors.
    // Parse tensors before nodes, as ParseNodes checks input tensors for the
    // nodes.
    if (ParseTensors(buffers, tensors, modified_subgraph) != kTfLiteOk)
      return cleanup_and_error();
    if (operators && ParseNodes(operators, modified_subgraph) != kTfLiteOk)
      return cleanup_and_error();

    std::vector<int> variables;
    for (int i = 0; i < modified_subgraph->tensors_size(); ++i) {
      auto* tensor = modified_subgraph->tensor(i);
      if (tensor->is_variable) {
        variables.push_back(i);
      }
    }
    modified_subgraph->SetVariables(std::move(variables));
    if (subgraph->name()) {
      modified_subgraph->SetName(subgraph->name()->c_str());
    }
  }

  if (ParseSignatureDefs(model_->signature_defs(), interpreter->get()) !=
      kTfLiteOk) {
    return cleanup_and_error();
  }

  if ((*interpreter)->SetMetadata(metadata_) != kTfLiteOk) {
    return cleanup_and_error();
  }

  if (ShouldCreateLazyDelegateProviders(num_fp32_tensors_)) {
    (*interpreter)->lazy_delegate_providers_ =
        op_resolver_.GetDelegateCreators();
  }

  TfLiteStatus status = ApplyDelegates(interpreter->get());
  if (status != kTfLiteOk) {
    interpreter->reset();
  }
  return status;
}

void InterpreterBuilder::AddDelegate(TfLiteDelegate* delegate) {
  if (delegate == nullptr) {
    TF_LITE_REPORT_ERROR(error_reporter_, "Null delegate.");
  } else {
    delegates_.push_back(delegate);
  }
}
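
// Delegates added here are applied, in order, when operator() builds the
// interpreter. A sketch (assumes a caller-owned delegate that outlives the
// interpreter, since the builder does not take ownership):
//
//   InterpreterBuilder builder(*model, resolver);
//   builder.AddDelegate(my_delegate);  // my_delegate: TfLiteDelegate*
//   std::unique_ptr<Interpreter> interpreter;
//   builder(&interpreter);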

}  // namespace tflite