1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 /// \file 16 /// Provides functionality to construct an interpreter for a model. 17 /// 18 #ifndef TENSORFLOW_LITE_INTERPRETER_BUILDER_H_ 19 #define TENSORFLOW_LITE_INTERPRETER_BUILDER_H_ 20 21 #include <map> 22 #include <memory> 23 #include <string> 24 #include <vector> 25 26 #include "flatbuffers/flatbuffers.h" // from @flatbuffers 27 #include "tensorflow/lite/allocation.h" 28 #include "tensorflow/lite/c/common.h" 29 #include "tensorflow/lite/core/api/error_reporter.h" 30 #include "tensorflow/lite/core/api/op_resolver.h" 31 #include "tensorflow/lite/core/subgraph.h" 32 #include "tensorflow/lite/interpreter.h" 33 #include "tensorflow/lite/model_builder.h" 34 #include "tensorflow/lite/mutable_op_resolver.h" 35 #include "tensorflow/lite/schema/schema_generated.h" 36 #include "tensorflow/lite/stderr_reporter.h" 37 38 namespace tflite { 39 40 /// Build an interpreter capable of interpreting `model`. 41 /// 42 /// `model`: A model whose lifetime must be at least as long as any 43 /// interpreter(s) created by the builder. In principle multiple interpreters 44 /// can be made from a single model. 45 /// `op_resolver`: An instance that implements the `OpResolver` interface, which 46 /// maps custom op names and builtin op codes to op registrations. The 47 /// lifetime of the provided `op_resolver` object must be at least as long as 48 /// the `InterpreterBuilder`; unlike `model` and `error_reporter`, the 49 /// `op_resolver` does not need to exist for the duration of any created 50 /// `Interpreter` objects. 51 /// `error_reporter`: a functor that is called to report errors that handles 52 /// printf var arg semantics. The lifetime of the `error_reporter` object must 53 /// be greater than or equal to the `Interpreter` created by `operator()`. 54 /// `options_experimental`: Options that can change behavior of interpreter. 55 /// WARNING: this parameter is an experimental API and is subject to change. 56 /// 57 /// Returns a kTfLiteOk when successful and sets interpreter to a valid 58 /// Interpreter. Note: The user must ensure the lifetime of the model (and error 59 /// reporter, if provided) is at least as long as interpreter's lifetime, and 60 /// a single model instance may safely be used with multiple interpreters. 61 class InterpreterBuilder { 62 public: 63 /// For this constructor, the ErrorReporter will be extracted from the 64 /// FlatBufferModel. 65 /// `options` object is copied during construction. So caller can release it 66 // after calling the constructor. 67 InterpreterBuilder(const FlatBufferModel& model, 68 const OpResolver& op_resolver, 69 const InterpreterOptions* options_experimental = nullptr); 70 /// Builds an interpreter given only the raw flatbuffer Model object (instead 71 /// of a FlatBufferModel). Mostly used for testing. 72 /// If `error_reporter` is null, then DefaultErrorReporter() is used. 73 /// `options` object is copied during construction. So caller can release it 74 // after calling the constructor. 75 InterpreterBuilder(const ::tflite::Model* model, 76 const OpResolver& op_resolver, 77 ErrorReporter* error_reporter = DefaultErrorReporter(), 78 const InterpreterOptions* options_experimental = nullptr); 79 ~InterpreterBuilder(); 80 InterpreterBuilder(const InterpreterBuilder&) = delete; 81 InterpreterBuilder& operator=(const InterpreterBuilder&) = delete; 82 83 /// Builds an interpreter and stores it in `*interpreter`. 84 /// On success, returns kTfLiteOk and sets `*interpreter` to a valid 85 /// Interpreter. 86 /// On failure, returns an error status and sets `*interpreter` to nullptr. 87 TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter); 88 89 /// Same as above, but also sets the number of CPU threads to use 90 /// (overriding any previous call to SetNumThreads). 91 /// Deprecated: use the SetNumThreads method instead. 92 TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter, 93 int num_threads); 94 95 /// Sets the number of CPU threads to use for the interpreter. 96 /// Returns kTfLiteOk on success, kTfLiteError on error. 97 TfLiteStatus SetNumThreads(int num_threads); 98 99 /// Any delegates added with AddDelegate will be applied to the Interpreter 100 /// generated by operator(), in the order that they were added. (The delegate 101 /// parameter passed to AddDelegate should be non-null, otherwise an error 102 /// will be reported, and the call to AddDelegate will have no other effect.) 103 /// The lifetime of the delegate must be at least as long as the lifetime of 104 /// any Interpreter generated by this InterpreterBuilder. 105 void AddDelegate(TfLiteDelegate* delegate); 106 107 private: 108 TfLiteStatus BuildLocalIndexToRegistrationMapping(); 109 TfLiteStatus ParseNodes( 110 const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators, 111 Subgraph* subgraph); 112 TfLiteStatus ParseTensors( 113 const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers, 114 const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors, 115 Subgraph* subgraph); 116 TfLiteStatus ApplyDelegates(Interpreter* interpreter); 117 TfLiteStatus ParseQuantization(const QuantizationParameters* src_quantization, 118 TfLiteQuantization* quantization, 119 const std::vector<int>& dims); 120 TfLiteStatus ParseSparsity(const SparsityParameters* src_sparsity, 121 TfLiteSparsity** sparsity); 122 TfLiteStatus ParseSignatureDefs( 123 const flatbuffers::Vector<flatbuffers::Offset<SignatureDef>>* 124 signature_def_list, 125 Interpreter* interpreter); 126 127 const ::tflite::Model* model_; 128 const OpResolver& op_resolver_; 129 ErrorReporter* error_reporter_; 130 std::vector<TfLiteDelegate*> delegates_; 131 // Model metadata stored as mapping of name (key) to buffer (value). 132 // Data is mapped from the Metadata in TFLite flatbuffer model. 133 // TODO(b/188185962): Consider mapping to std::pair<const char*, size_t> if 134 // this increases runtime memory usage for large metadata. 135 std::map<std::string, std::string> metadata_; 136 137 std::vector<const TfLiteRegistration*> flatbuffer_op_index_to_registration_; 138 std::vector<TfLiteRegistration> unresolved_custom_ops_; 139 std::vector<BuiltinOperator> flatbuffer_op_index_to_registration_types_; 140 const Allocation* allocation_ = nullptr; 141 142 bool has_flex_op_ = false; 143 int num_fp32_tensors_ = 0; 144 int num_threads_ = -1; 145 InterpreterOptions options_; 146 }; 147 148 } // namespace tflite 149 150 #endif // TENSORFLOW_LITE_INTERPRETER_BUILDER_H_ 151