xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/interpreter_builder.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 /// \file
16 /// Provides functionality to construct an interpreter for a model.
17 ///
18 #ifndef TENSORFLOW_LITE_INTERPRETER_BUILDER_H_
19 #define TENSORFLOW_LITE_INTERPRETER_BUILDER_H_
20 
21 #include <map>
22 #include <memory>
23 #include <string>
24 #include <vector>
25 
26 #include "flatbuffers/flatbuffers.h"  // from @flatbuffers
27 #include "tensorflow/lite/allocation.h"
28 #include "tensorflow/lite/c/common.h"
29 #include "tensorflow/lite/core/api/error_reporter.h"
30 #include "tensorflow/lite/core/api/op_resolver.h"
31 #include "tensorflow/lite/core/subgraph.h"
32 #include "tensorflow/lite/interpreter.h"
33 #include "tensorflow/lite/model_builder.h"
34 #include "tensorflow/lite/mutable_op_resolver.h"
35 #include "tensorflow/lite/schema/schema_generated.h"
36 #include "tensorflow/lite/stderr_reporter.h"
37 
38 namespace tflite {
39 
40 /// Build an interpreter capable of interpreting `model`.
41 ///
42 /// `model`: A model whose lifetime must be at least as long as any
43 ///   interpreter(s) created by the builder. In principle multiple interpreters
44 ///   can be made from a single model.
45 /// `op_resolver`: An instance that implements the `OpResolver` interface, which
46 ///   maps custom op names and builtin op codes to op registrations. The
47 ///   lifetime of the provided `op_resolver` object must be at least as long as
48 ///   the `InterpreterBuilder`; unlike `model` and `error_reporter`, the
49 ///   `op_resolver` does not need to exist for the duration of any created
50 ///   `Interpreter` objects.
51 /// `error_reporter`: a functor that is called to report errors and that
52 ///   handles printf-style variadic arguments. The `error_reporter` object must
53 ///   live at least as long as the `Interpreter` created by `operator()`.
54 /// `options_experimental`: Options that can change behavior of interpreter.
55 ///   WARNING: this parameter is an experimental API and is subject to change.
56 ///
57 /// `operator()` returns kTfLiteOk on success and sets `*interpreter` to a valid
58 /// Interpreter. Note: The user must ensure the lifetime of the model (and error
59 /// reporter, if provided) is at least as long as the interpreter's lifetime. A
60 /// single model instance may safely be used with multiple interpreters.
61 class InterpreterBuilder {
62  public:
63   /// For this constructor, the ErrorReporter will be extracted from the
64   /// FlatBufferModel.
65   /// The `options_experimental` object is copied during construction, so the
66   /// caller can release it after calling the constructor.
67   InterpreterBuilder(const FlatBufferModel& model,
68                      const OpResolver& op_resolver,
69                      const InterpreterOptions* options_experimental = nullptr);
70   /// Builds an interpreter given only the raw flatbuffer Model object (instead
71   /// of a FlatBufferModel). Mostly used for testing.
72   /// If `error_reporter` is null, then DefaultErrorReporter() is used.
73   /// The `options_experimental` object is copied during construction, so the
74   /// caller can release it after calling the constructor.
75   InterpreterBuilder(const ::tflite::Model* model,
76                      const OpResolver& op_resolver,
77                      ErrorReporter* error_reporter = DefaultErrorReporter(),
78                      const InterpreterOptions* options_experimental = nullptr);
79   ~InterpreterBuilder();
80   InterpreterBuilder(const InterpreterBuilder&) = delete;  // Not copyable.
81   InterpreterBuilder& operator=(const InterpreterBuilder&) = delete;
82 
83   /// Builds an interpreter and stores it in `*interpreter`.
84   /// On success, returns kTfLiteOk and sets `*interpreter` to a valid
85   /// Interpreter.
86   /// On failure, returns an error status and sets `*interpreter` to nullptr.
87   TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter);
88 
89   /// Same as above, but also sets the number of CPU threads to use
90   /// (overriding any previous call to SetNumThreads).
91   /// Deprecated: use the SetNumThreads method instead.
92   TfLiteStatus operator()(std::unique_ptr<Interpreter>* interpreter,
93                           int num_threads);
94 
95   /// Sets the number of CPU threads to use for the interpreter.
96   /// Returns kTfLiteOk on success, kTfLiteError on error.
97   TfLiteStatus SetNumThreads(int num_threads);
98 
99   /// Any delegates added with AddDelegate will be applied to the Interpreter
100   /// generated by operator(), in the order that they were added.  (The delegate
101   /// parameter passed to AddDelegate should be non-null, otherwise an error
102   /// will be reported, and the call to AddDelegate will have no other effect.)
103   /// The lifetime of the delegate must be at least as long as the lifetime of
104   /// any Interpreter generated by this InterpreterBuilder.
105   void AddDelegate(TfLiteDelegate* delegate);
106 
107  private:
108   TfLiteStatus BuildLocalIndexToRegistrationMapping();
109   TfLiteStatus ParseNodes(
110       const flatbuffers::Vector<flatbuffers::Offset<Operator>>* operators,
111       Subgraph* subgraph);
112   TfLiteStatus ParseTensors(
113       const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers,
114       const flatbuffers::Vector<flatbuffers::Offset<Tensor>>* tensors,
115       Subgraph* subgraph);
116   TfLiteStatus ApplyDelegates(Interpreter* interpreter);
117   TfLiteStatus ParseQuantization(const QuantizationParameters* src_quantization,
118                                  TfLiteQuantization* quantization,
119                                  const std::vector<int>& dims);
120   TfLiteStatus ParseSparsity(const SparsityParameters* src_sparsity,
121                              TfLiteSparsity** sparsity);
122   TfLiteStatus ParseSignatureDefs(
123       const flatbuffers::Vector<flatbuffers::Offset<SignatureDef>>*
124           signature_def_list,
125       Interpreter* interpreter);
126 
127   const ::tflite::Model* model_;  // Must outlive created interpreters (see class docs).
128   const OpResolver& op_resolver_;  // Must outlive this builder (see class docs).
129   ErrorReporter* error_reporter_;
130   std::vector<TfLiteDelegate*> delegates_;  // Presumably the AddDelegate() arguments, in order — confirm in the .cc.
131   // Model metadata stored as mapping of name (key) to buffer (value).
132   // Data is mapped from the Metadata in TFLite flatbuffer model.
133   // TODO(b/188185962): Consider mapping to std::pair<const char*, size_t> if
134   // this increases runtime memory usage for large metadata.
135   std::map<std::string, std::string> metadata_;
136 
137   std::vector<const TfLiteRegistration*> flatbuffer_op_index_to_registration_;
138   std::vector<TfLiteRegistration> unresolved_custom_ops_;
139   std::vector<BuiltinOperator> flatbuffer_op_index_to_registration_types_;
140   const Allocation* allocation_ = nullptr;
141 
142   bool has_flex_op_ = false;
143   int num_fp32_tensors_ = 0;
144   int num_threads_ = -1;  // NOTE(review): -1 presumably means "not set" — confirm in the .cc.
145   InterpreterOptions options_;  // Held by value; constructor docs say the caller's options are copied.
146 };
147 
148 }  // namespace tflite
149 
150 #endif  // TENSORFLOW_LITE_INTERPRETER_BUILDER_H_
151