xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/interpreter.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 /// \file
16 /// Main abstraction controlling the tflite interpreter.
17 /// See c/common.h for the API for defining operations (TfLiteRegistration).
18 #ifndef TENSORFLOW_LITE_INTERPRETER_H_
19 #define TENSORFLOW_LITE_INTERPRETER_H_
20 
21 #include <stddef.h>
22 #include <stdint.h>
23 
24 #include <complex>
25 #include <cstdio>
26 #include <cstdlib>
27 #include <functional>
28 #include <map>
29 #include <memory>
30 #include <string>
31 #include <utility>
32 #include <vector>
33 
34 #include "tensorflow/lite/allocation.h"
35 #include "tensorflow/lite/c/common.h"  // IWYU pragma: export
36 #include "tensorflow/lite/core/api/error_reporter.h"
37 #include "tensorflow/lite/core/api/profiler.h"
38 #include "tensorflow/lite/core/subgraph.h"
39 #include "tensorflow/lite/experimental/resource/initialization_status.h"
40 #include "tensorflow/lite/experimental/resource/resource_base.h"
41 #include "tensorflow/lite/external_cpu_backend_context.h"
42 #include "tensorflow/lite/internal/signature_def.h"
43 #include "tensorflow/lite/interpreter_options.h"
44 #include "tensorflow/lite/portable_type_to_tflitetype.h"
45 #include "tensorflow/lite/profiling/root_profiler.h"
46 #include "tensorflow/lite/signature_runner.h"
47 #include "tensorflow/lite/stderr_reporter.h"
48 #include "tensorflow/lite/string_type.h"
49 #include "tensorflow/lite/type_to_tflitetype.h"
50 
51 namespace tflite {
52 
53 class InterpreterTest;  // Class for friend declarations.
54 
55 namespace delegates {
56 class InterpreterUtils;  // Class for friend declarations.
57 
58 namespace test_utils {
59 class TestDelegation;  // Class for friend declarations.
60 }  // namespace test_utils
61 }  // namespace delegates
62 
63 namespace interpreter_wrapper {
64 class InterpreterWrapper;  // Class for friend declarations.
65 }  // namespace interpreter_wrapper
66 
67 /// An interpreter for a graph of nodes that input and output from tensors.
68 /// Each node of the graph processes a set of input tensors and produces a
69 /// set of output Tensors. All inputs/output tensors are referenced by index.
70 ///
71 /// Usage:
72 ///
73 /// <pre><code>
74 /// // Create model from file. Note that the model instance must outlive the
75 /// // interpreter instance.
76 /// auto model = tflite::FlatBufferModel::BuildFromFile(...);
77 /// if (model == nullptr) {
78 ///   // Return error.
79 /// }
80 /// // Create an Interpreter with an InterpreterBuilder.
81 /// std::unique_ptr<tflite::Interpreter> interpreter;
82 /// tflite::ops::builtin::BuiltinOpResolver resolver;
83 /// if (InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
84 ///   // Return failure.
85 /// }
86 /// if (interpreter->AllocateTensors() != kTfLiteOk) {
87 ///   // Return failure.
88 /// }
89 ///
90 /// auto input = interpreter->typed_tensor<float>(0);
91 /// for (int i = 0; i < input_size; i++) {
92 ///   input[i] = ...;
/// }
94 /// interpreter->Invoke();
95 /// </code></pre>
96 ///
97 /// Note: For nearly all practical use cases, one should not directly construct
98 /// an Interpreter object, but rather use the InterpreterBuilder.
99 ///
100 /// WARNING: This class is *not* thread-safe. The client is responsible for
101 /// ensuring serialized interaction to avoid data races and undefined behavior.
102 class Interpreter {
103  public:
104   // Instantiate an interpreter. All errors associated with reading and
105   // processing this model will be forwarded to the error_reporter object.
106   //
107   // Note, if error_reporter is nullptr, then a default StderrReporter is
108   // used. Ownership of 'error_reporter' remains with the caller.
109   // WARNING: Use of this constructor outside of an InterpreterBuilder is not
110   // recommended.
111   explicit Interpreter(ErrorReporter* error_reporter = DefaultErrorReporter());
112 
113   ~Interpreter();
114 
115   // Interpreters are not copyable as they have non-trivial memory semantics.
116   Interpreter(const Interpreter&) = delete;
117   Interpreter& operator=(const Interpreter&) = delete;
118 
119   // Functions to build interpreter
120 #ifndef DOXYGEN_SKIP
121   /// Provide a list of tensor indexes that are inputs to the model.
122   /// Each index is bound check and this modifies the consistent_ flag of the
123   /// interpreter.
124   TfLiteStatus SetInputs(std::vector<int> inputs);
125 
126   /// Provide a list of tensor indexes that are outputs to the model
127   /// Each index is bound check and this modifies the consistent_ flag of the
128   /// interpreter.
129   TfLiteStatus SetOutputs(std::vector<int> outputs);
130 
131   /// Provide a list of tensor indexes that are variable tensors.
132   /// Each index is bound check and this modifies the consistent_ flag of the
133   /// interpreter.
134   TfLiteStatus SetVariables(std::vector<int> variables);
135 
136   /// Adds a node with the given parameters and returns the index of the new
137   /// node in `node_index` (optionally). Interpreter will take ownership of
138   /// `builtin_data` and destroy it with `free`. Ownership of 'init_data'
139   /// remains with the caller.
140   TfLiteStatus AddNodeWithParameters(const std::vector<int>& inputs,
141                                      const std::vector<int>& outputs,
142                                      const char* init_data,
143                                      size_t init_data_size, void* builtin_data,
144                                      const TfLiteRegistration* registration,
145                                      int* node_index = nullptr);
146 
147   /// Adds `tensors_to_add` tensors, preserving pre-existing Tensor entries.
148   /// The value pointed to by `first_new_tensor_index` will be set to the
149   /// index of the first new tensor if `first_new_tensor_index` is non-null.
150   TfLiteStatus AddTensors(int tensors_to_add,
151                           int* first_new_tensor_index = nullptr);
152 
153   /// Set description of inputs/outputs/data/fptrs for node `node_index`.
154   /// This variant assumes an external buffer has been allocated of size
155   /// bytes. The lifetime of buffer must be ensured to be greater or equal
156   /// to Interpreter.
157   TfLiteStatus SetTensorParametersReadOnly(
158       int tensor_index, TfLiteType type, const char* name,
159       const std::vector<int>& dims, TfLiteQuantization quantization,
160       const char* buffer, size_t bytes, const Allocation* allocation = nullptr);
161 
162   /// Legacy. Deprecated in favor of above.
163   inline TfLiteStatus SetTensorParametersReadOnly(
164       int tensor_index, TfLiteType type, const char* name,
165       const std::vector<int>& dims, TfLiteQuantizationParams quantization,
166       const char* buffer, size_t bytes,
167       const Allocation* allocation = nullptr) {
168     return SetTensorParametersReadOnly(tensor_index, type, name, dims.size(),
169                                        dims.data(), quantization, buffer, bytes,
170                                        allocation);
171   }
172 
173   TfLiteStatus SetTensorParametersReadOnly(
174       int tensor_index, TfLiteType type, const char* name, const size_t rank,
175       const int* dims, TfLiteQuantizationParams quantization,
176       const char* buffer, size_t bytes, const Allocation* allocation = nullptr);
177 
178   /// Set description of inputs/outputs/data/fptrs for node `node_index`.
179   /// This variant assumes an external buffer has been allocated of size
180   /// bytes. The lifetime of buffer must be ensured to be greater or equal
181   /// to Interpreter.
182   TfLiteStatus SetTensorParametersReadWrite(int tensor_index, TfLiteType type,
183                                             const char* name,
184                                             const std::vector<int>& dims,
185                                             TfLiteQuantization quantization,
186                                             bool is_variable = false);
187 
188   /// Legacy. Deprecated in favor of above.
189   inline TfLiteStatus SetTensorParametersReadWrite(
190       int tensor_index, TfLiteType type, const char* name,
191       const std::vector<int>& dims, TfLiteQuantizationParams quantization,
192       bool is_variable = false,
193       const std::vector<int>* dims_signature = nullptr) {
194     size_t rank_dims_signature = 0;
195     const int* dims_signature_pointer = nullptr;
196     if (dims_signature) {
197       rank_dims_signature = dims_signature->size();
198       dims_signature_pointer = dims_signature->data();
199     }
200     return SetTensorParametersReadWrite(
201         tensor_index, type, name, dims.size(), dims.data(), quantization,
202         is_variable, rank_dims_signature, dims_signature_pointer);
203   }
204   TfLiteStatus SetTensorParametersReadWrite(
205       int tensor_index, TfLiteType type, const char* name, const size_t rank,
206       const int* dims, TfLiteQuantizationParams quantization,
207       bool is_variable = false, const size_t rank_dims_signature = 0,
208       const int* dims_signature = nullptr);
209 #endif  // DOXYGEN_SKIP
210   // Functions to access tensor data
211 
212   /// Read only access to list of inputs.
inputs()213   const std::vector<int>& inputs() const { return primary_subgraph().inputs(); }
214 
215   /// Return the name of a given input. The given index must be between 0 and
216   /// inputs().size().
GetInputName(int index)217   const char* GetInputName(int index) const {
218     return context_->tensors[inputs()[index]].name;
219   }
220 
221   /// Read only access to list of outputs.
outputs()222   const std::vector<int>& outputs() const {
223     return primary_subgraph().outputs();
224   }
225 
226   /// Read only access to list of variable tensors.
variables()227   const std::vector<int>& variables() const {
228     return primary_subgraph().variables();
229   }
230 
231   /// Return the name of a given output. The given index must be between 0 and
232   /// outputs().size().
GetOutputName(int index)233   const char* GetOutputName(int index) const {
234     return context_->tensors[outputs()[index]].name;
235   }
236 
237   /// Return the number of tensors in the model.
tensors_size()238   size_t tensors_size() const { return context_->tensors_size; }
239 
240   /// Return the number of ops in the model.
nodes_size()241   size_t nodes_size() const { return primary_subgraph().nodes_size(); }
242 
243   /// WARNING: Experimental interface, subject to change
execution_plan()244   const std::vector<int>& execution_plan() const {
245     return primary_subgraph().execution_plan();
246   }
247 
248   /// Get a mutable tensor data structure.
249   // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this
250   // read/write access to structure
tensor(int tensor_index)251   TfLiteTensor* tensor(int tensor_index) {
252     return primary_subgraph().tensor(tensor_index);
253   }
254 
255   /// Get an immutable tensor data structure.
tensor(int tensor_index)256   const TfLiteTensor* tensor(int tensor_index) const {
257     return primary_subgraph().tensor(tensor_index);
258   }
259 
260   /// Returns a pointer to an operation and registration data structure if in
261   /// bounds from the primary subgraph(subgraph_[0]).
node_and_registration(int node_index)262   const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration(
263       int node_index) const {
264     return primary_subgraph().node_and_registration(node_index);
265   }
266 
267   /// Returns a pointer to an operation and registration data structure if in
268   /// bounds.
node_and_registration(int subgraph_index,int node_index)269   const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration(
270       int subgraph_index, int node_index) const {
271     return subgraph(subgraph_index)->node_and_registration(node_index);
272   }
273 
274   /// Perform a checked cast to the appropriate tensor type (mutable pointer
275   /// version).
276   template <class T>
typed_tensor(int tensor_index)277   T* typed_tensor(int tensor_index) {
278     if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
279       if (tensor_ptr->type == typeToTfLiteType<T>()) {
280         return reinterpret_cast<T*>(tensor_ptr->data.raw);
281       }
282     }
283     return nullptr;
284   }
285 
286   /// Perform a checked cast to the appropriate tensor type (immutable pointer
287   /// version).
288   template <class T>
typed_tensor(int tensor_index)289   const T* typed_tensor(int tensor_index) const {
290     if (const TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
291       if (tensor_ptr->type == typeToTfLiteType<T>()) {
292         return reinterpret_cast<const T*>(tensor_ptr->data.raw);
293       }
294     }
295     return nullptr;
296   }
297 
298   /// WARNING: Experimental interface, subject to change
299   /// Returns list of all keys of different method signatures defined in the
300   /// model.
301   /// Note, pointers returned have lifetime same as the Interpreter object.
signature_keys()302   std::vector<const std::string*> signature_keys() const {
303     std::vector<const std::string*> signature_keys;
304     signature_keys.reserve(signature_defs_.size());
305     for (const auto& sig_def : signature_defs_) {
306       signature_keys.emplace_back(&sig_def.signature_key);
307     }
308     return signature_keys;
309   }
310 
311   /// WARNING: Experimental interface, subject to change
312   /// Returns a pointer to the SignatureRunner instance to run the part of the
313   /// graph identified by a SignatureDef. The nullptr is returned if the given
314   /// signature key is not valid.
315   /// If you need to specify delegates, you have to do that before calling this
316   /// function. This function will additionally apply default delegates. Thus,
317   /// applying delegates after that might lead to undesirable behaviors.
318   /// Note, the pointed instance has lifetime same as the Interpreter object
319   /// and the SignatureRunner class is *not* thread-safe.
320   SignatureRunner* GetSignatureRunner(const char* signature_key);
321 
322   /// WARNING: Experimental interface, subject to change
323   /// Return the subgraph index that corresponds to a SignatureDef, defined by
324   /// 'signature_key'.
325   /// If invalid name passed, -1 will be returned.
GetSubgraphIndexFromSignature(const char * signature_key)326   int GetSubgraphIndexFromSignature(const char* signature_key) const {
327     for (const auto& signature : signature_defs_) {
328       if (signature.signature_key == signature_key) {
329         return signature.subgraph_index;
330       }
331     }
332     return -1;
333   }
334 
335   /// WARNING: Experimental interface, subject to change
336   /// Returns the mapping of inputs to tensor index in the signature
337   /// specified through 'signature_key'.
338   /// If invalid name passed, an empty list will be returned.
signature_inputs(const char * signature_key)339   const std::map<std::string, uint32_t>& signature_inputs(
340       const char* signature_key) const {
341     for (const auto& sig_def : signature_defs_) {
342       if (sig_def.signature_key == signature_key) return sig_def.inputs;
343     }
344     static const std::map<std::string, uint32_t>* default_empty_list =
345         new std::map<std::string, uint32_t>();
346     return *default_empty_list;
347   }
348 
349   /// WARNING: Experimental interface, subject to change
350   /// Returns the mapping of outputs to tensor index in the signature
351   /// specified through 'signature_key'.
352   /// If invalid name passed, an empty list will be returned.
signature_outputs(const char * signature_key)353   const std::map<std::string, uint32_t>& signature_outputs(
354       const char* signature_key) const {
355     for (const auto& sig_def : signature_defs_) {
356       if (sig_def.signature_key == signature_key) return sig_def.outputs;
357     }
358     static const std::map<std::string, uint32_t>* default_empty_list =
359         new std::map<std::string, uint32_t>();
360     return *default_empty_list;
361   }
362 
363   /// WARNING: Experimental interface, subject to change
364   /// Returns the input tensor identified by 'signature_input_name' in the
365   /// signature identified by 'signature_key'.
366   /// Returns nullptr if not found.
input_tensor_by_signature(const char * signature_input_name,const char * signature_key)367   TfLiteTensor* input_tensor_by_signature(const char* signature_input_name,
368                                           const char* signature_key) {
369     const int subgraph_index = GetSubgraphIndexFromSignature(signature_key);
370     if (subgraph_index == -1) return nullptr;
371     const int tensor_index = GetTensorIndexFromSignature(
372         signature_input_name, signature_key, /*is_input=*/true);
373     if (tensor_index == -1) return nullptr;
374     return subgraph(subgraph_index)->tensor(tensor_index);
375   }
376 
377   /// WARNING: Experimental interface, subject to change
378   /// Returns the output tensor identified by 'signature_output_name' in the
379   /// signature identified by 'signature_key'.
380   /// Returns nullptr if not found.
output_tensor_by_signature(const char * signature_output_name,const char * signature_key)381   const TfLiteTensor* output_tensor_by_signature(
382       const char* signature_output_name, const char* signature_key) const {
383     const int subgraph_index = GetSubgraphIndexFromSignature(signature_key);
384     if (subgraph_index == -1) return nullptr;
385     const int tensor_index = GetTensorIndexFromSignature(
386         signature_output_name, signature_key, /*is_input=*/false);
387     if (tensor_index == -1) return nullptr;
388     return subgraph(subgraph_index)->tensor(tensor_index);
389   }
390 
391   /// Return a mutable pointer to the given input tensor. The given index must
392   /// be between 0 and inputs().size().
input_tensor(size_t index)393   TfLiteTensor* input_tensor(size_t index) { return tensor(inputs()[index]); }
394 
395   /// Return an immutable pointer to the given input tensor. The given index
396   /// must be between 0 and inputs().size().
input_tensor(size_t index)397   const TfLiteTensor* input_tensor(size_t index) const {
398     return tensor(inputs()[index]);
399   }
400 
401   /// Return a mutable pointer into the data of a given input tensor. The given
402   /// index must be between 0 and inputs().size().
403   template <class T>
typed_input_tensor(int index)404   T* typed_input_tensor(int index) {
405     return typed_tensor<T>(inputs()[index]);
406   }
407 
408   /// Return an immutable pointer into the data of a given input tensor. The
409   /// given index must be between 0 and inputs().size().
410   template <class T>
typed_input_tensor(int index)411   const T* typed_input_tensor(int index) const {
412     return typed_tensor<T>(inputs()[index]);
413   }
414 
415   /// Return a mutable pointer to the given output tensor. The given index must
416   /// be between 0 and outputs().size().
output_tensor(size_t index)417   TfLiteTensor* output_tensor(size_t index) { return tensor(outputs()[index]); }
418 
419   /// Return an immutable pointer to the given output tensor. The given index
420   /// must be between 0 and outputs().size().
output_tensor(size_t index)421   const TfLiteTensor* output_tensor(size_t index) const {
422     return tensor(outputs()[index]);
423   }
424 
425   /// Return a mutable pointer into the data of a given output tensor. The given
426   /// index must be between 0 and outputs().size().
427   template <class T>
typed_output_tensor(int index)428   T* typed_output_tensor(int index) {
429     return typed_tensor<T>(outputs()[index]);
430   }
431 
432   /// Return an immutable pointer into the data of a given output tensor. The
433   /// given index must be between 0 and outputs().size().
434   template <class T>
typed_output_tensor(int index)435   const T* typed_output_tensor(int index) const {
436     return typed_tensor<T>(outputs()[index]);
437   }
438 
439   /// Change the dimensionality of a given tensor. Note, this is only acceptable
440   /// for tensor indices that are inputs or variables.
441   /// Returns status of failure or success. Note that this doesn't actually
442   /// resize any existing buffers. A call to AllocateTensors() is required to
443   /// change the tensor input buffer.
444   TfLiteStatus ResizeInputTensor(int tensor_index,
445                                  const std::vector<int>& dims);
446 
447   /// Change the dimensionality of a given tensor. This is only acceptable for
448   /// tensor indices that are inputs or variables. Only unknown dimensions can
449   /// be resized with this function. Unknown dimensions are indicated as `-1` in
450   /// the `dims_signature` attribute of a `TfLiteTensor`. Returns status of
451   /// failure or success.  Note that this doesn't actually resize any existing
452   /// buffers. A call to AllocateTensors() is required to change the tensor
453   /// input buffer.
454   TfLiteStatus ResizeInputTensorStrict(int tensor_index,
455                                        const std::vector<int>& dims);
456 
457   /// This releases memory held by non-persistent tensors. It does NOT
458   /// re-perform memory planning. AllocateTensors needs to be called before next
459   /// invocation. WARNING: Experimental interface, subject to change
460   TfLiteStatus ReleaseNonPersistentMemory();
461 
462 
463   /// Update allocations for all tensors. This will redim dependent tensors
464   /// using the input tensor dimensionality as given. This is relatively
465   /// expensive. This *must be* called after the interpreter has been created
466   /// and before running inference (and accessing tensor buffers), and *must be*
467   /// called again if (and only if) an input tensor is resized. Returns status
468   /// of success or failure.  Will fail if any of the ops in the model (other
469   /// than those which were rewritten by delegates, if any) are not supported by
470   /// the Interpreter's OpResolver.
471   TfLiteStatus AllocateTensors();
472 
473   /// Invoke the interpreter (run the whole graph in dependency order).
474   ///
475   /// NOTE: It is possible that the interpreter is not in a ready state
476   /// to evaluate (i.e. if a ResizeTensor() has been performed without an
/// AllocateTensors()).
478   /// Returns status of success or failure.
479   TfLiteStatus Invoke();
480 
481   /// Set the number of threads available to the interpreter.
482   ///
483   /// NOTE: `num_threads` should be >= -1. Setting `num_threads` to 0 has the
484   /// effect to disable multithreading, which is equivalent to setting
485   /// `num_threads` to 1. If set to the value -1, the number of threads used
486   /// will be implementation-defined and platform-dependent.
487   ///
488   /// As TfLite interpreter could internally apply a TfLite delegate by default
489   /// (i.e. XNNPACK), the number of threads that are available to the default
490   /// delegate *should be* set via InterpreterBuilder APIs as follows:
491   ///
492   ///     std::unique_ptr<tflite::Interpreter> interpreter;
493   ///     tflite::InterpreterBuilder builder(tflite model, op resolver);
494   ///     builder.SetNumThreads(...)
495   ///     ASSERT_EQ(builder(&interpreter), kTfLiteOk);
496   ///
497   /// WARNING: This API is deprecated: prefer using
498   /// `InterpreterBuilder::SetNumThreads`, as documented above.
499   TfLiteStatus SetNumThreads(int num_threads);
500 
501   /// Allow float16 precision for FP32 calculation when possible.
502   /// Default: not allow.
503   ///
504   /// WARNING: This API is deprecated: prefer controlling this via delegate
505   /// options, e.g. `tflite::StatefulNnApiDelegate::Options::allow_fp16' or
506   /// `TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`.
507   /// This method will be removed in a future release.
508   void SetAllowFp16PrecisionForFp32(bool allow);
509 
510   /// Get the half precision flag.
511   /// WARNING: This is an experimental API and subject to change.
GetAllowFp16PrecisionForFp32()512   bool GetAllowFp16PrecisionForFp32() const {
513     return context_->allow_fp32_relax_to_fp16;
514   }
515 
516   /// Sets the cancellation function pointer in order to cancel a request in the
517   /// middle of a call to Invoke(). The interpreter queries this function during
518   /// inference, between op invocations; when it returns true, the interpreter
519   /// will abort execution and return `kTfLiteError`. The `data` parameter
520   /// contains any data used by the cancellation function, and if non-null,
521   /// remains owned by the caller.
522   /// WARNING: This is an experimental API and subject to change.
523   void SetCancellationFunction(void* data, bool (*check_cancelled_func)(void*));
524 
525   /// Allow a delegate to look at the graph and modify the graph to handle
526   /// parts of the graph themselves. After this is called, the graph may
/// contain new nodes that replace one or more nodes.
528   /// 'delegate' must outlive the interpreter.
529   /// Returns one of the following status codes:
530   /// 1. kTfLiteOk: Success.
531   /// 2. kTfLiteDelegateError: Delegation failed due to an error in the
532   /// delegate, or the delegate parameter was null. The Interpreter has been
533   /// restored to its pre-delegation state.
534   /// NOTE: This undoes all delegates previously applied to the Interpreter.
535   /// 3. kTfLiteApplicationError : Delegation failed to be applied due to the
536   /// incompatibility with the TfLite runtime, e.g., the model graph is already
537   /// immutable when applying the delegate. However, the interpreter could still
538   /// be invoked.
539   /// 4. kTfLiteUnresolvedOps: Delegation failed because the model has an
540   /// operator that cannot be resolved. This can happen when the op is not
541   /// registered or built with the TF Lite framework.
542   /// 5. kTfLiteError: Unexpected/runtime failure.
543   /// WARNING: This is an experimental API and subject to change.
544   TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate);
545 
546   // Owning handle to a TfLiteDelegate instance.
547   using TfLiteDelegatePtr =
548       std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;
549 
550   /// Same as ModifyGraphWithDelegate except this interpreter takes
551   /// ownership of the provided delegate.
552   /// WARNING: This is an experimental API and subject to change.
553   template <typename Delegate, typename Deleter>
ModifyGraphWithDelegate(std::unique_ptr<Delegate,Deleter> delegate)554   inline TfLiteStatus ModifyGraphWithDelegate(
555       std::unique_ptr<Delegate, Deleter> delegate) {
556     Deleter deleter = std::move(delegate.get_deleter());
557 
558     // Note that we retain ownership of the delegate even if graph modification
559     // fails, as delegate use will be in an indeterminate state at that point.
560     owned_delegates_.emplace_back(
561         delegate.release(), [deleter](TfLiteDelegate* delegate_to_delete) {
562           deleter(
563               static_cast<typename std::unique_ptr<Delegate, Deleter>::pointer>(
564                   delegate_to_delete));
565         });
566     return ModifyGraphWithDelegate(owned_delegates_.back().get());
567   }
568 
569   /// This overload is *never* OK. TfLiteDelegate is a C structure, so it has no
570   /// virtual destructor. The default deleter of the unique_ptr does not know
571   /// how to delete C++ objects deriving from TfLiteDelegate.
572   TfLiteStatus ModifyGraphWithDelegate(
573       std::unique_ptr<TfLiteDelegate> delegate) = delete;
574 
575   /// Ensure the data in `tensor.data` is readable. In case delegate is used,
576   /// it might require to copy the data from delegate buffer to raw memory.
577   /// WARNING: This is an experimental API and subject to change.
EnsureTensorDataIsReadable(int tensor_index)578   TfLiteStatus EnsureTensorDataIsReadable(int tensor_index) {
579     return primary_subgraph().EnsureTensorDataIsReadable(tensor_index);
580   }
581 
582   /// Set the delegate buffer handle to a tensor. It can be called in the
583   /// following cases:
584   /// 1. Set the buffer handle to a tensor that's not being written by a
585   ///    delegate. For example, feeding an OpenGL texture as the input of the
586   ///    inference graph.
587   /// 2. Set the buffer handle to a tensor that uses the same delegate.
588   ///    For example, set an OpenGL texture as the output of inference, while
589   ///    the node which produces output is an OpenGL delegate node.
590   /// WARNING: This is an experimental API and subject to change.
591   TfLiteStatus SetBufferHandle(int tensor_index,
592                                TfLiteBufferHandle buffer_handle,
593                                TfLiteDelegate* delegate);
594 
595   /// Get the delegate buffer handle, and the delegate which can process the
596   /// buffer handle.
597   /// WARNING: This is an experimental API and subject to change.
598   TfLiteStatus GetBufferHandle(int tensor_index,
599                                TfLiteBufferHandle* buffer_handle,
600                                TfLiteDelegate** delegate);
601 
602   /// Sets the profiler to tracing execution. The caller retains ownership
603   /// of the profiler and must ensure its validity.
604   /// Previously registered profilers will be unregistered.
605   /// If `profiler` is nullptr, all previously installed profilers will be
606   /// removed.
607   /// WARNING: This is an experimental API and subject to change.
608   void SetProfiler(Profiler* profiler);
609 
610   /// Same as SetProfiler except this interpreter takes ownership
611   /// of the provided profiler.
612   /// Previously registered profilers will be unregistered.
613   /// If `profiler` is nullptr, all previously installed profilers will be
614   /// removed.
615   /// WARNING: This is an experimental API and subject to change.
616   void SetProfiler(std::unique_ptr<Profiler> profiler);
617 
618   /// Adds the profiler to tracing execution. The caller retains ownership
619   /// of the profiler and must ensure its validity.
620   /// nullptr `profiler` will be ignored.
621   /// WARNING: This is an experimental API and subject to change.
622   void AddProfiler(Profiler* profiler);
623 
624   /// Gets the profiler used for op tracing.
625   /// WARNING: This is an experimental API and subject to change.
626   Profiler* GetProfiler();
627 
628   // The default capacity of `tensors_` vector.
629   static constexpr int kTensorsReservedCapacity = 128;
630   /// The capacity headroom of `tensors_` vector before calling ops'
631   /// `prepare` and `invoke` function. In these functions, it's guaranteed
632   /// allocating up to `kTensorsCapacityHeadroom` more tensors won't invalidate
633   /// pointers to existing tensors.
634   static constexpr int kTensorsCapacityHeadroom = 16;
635 
  /// Set if buffer handle output is allowed.
  ///
  /// When using hardware delegation, Interpreter will make the data of output
  /// tensors available in `tensor->data` by default. If the application can
  /// consume the buffer handle directly (e.g. reading output from OpenGL
  /// texture), it can set this flag to false, so Interpreter won't copy the
  /// data from buffer handle to CPU memory.
  /// NOTE(review): the member defaults to false and this setter simply stores
  /// the flag; the "set this flag to false" wording above looks inverted
  /// relative to `allow_buffer_handle_output_` — confirm intended polarity.
  /// WARNING: This is an experimental API and subject to change.
  void SetAllowBufferHandleOutput(bool allow_buffer_handle_output) {
    allow_buffer_handle_output_ = allow_buffer_handle_output;
  }
647 
  /// Reset all variable tensors to the default value.
  /// If a variable tensor doesn't have a buffer, reset it to zero.
  /// TODO(b/115961645): Implement - If a variable tensor has a buffer, reset it
  /// to the value of the buffer.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus ResetVariableTensors();
655   /// Retrieve an operator's description of its work, for profiling purposes.
OpProfilingString(const TfLiteRegistration & op_reg,const TfLiteNode * node)656   const char* OpProfilingString(const TfLiteRegistration& op_reg,
657                                 const TfLiteNode* node) const {
658     if (op_reg.profiling_string == nullptr) return nullptr;
659     return op_reg.profiling_string(context_, node);
660   }
661 
  // Set the value of an external context. TFLite interpreter doesn't take the
  // memory ownership of this external context 'ctx', and the context should
  // outlive the TFLite interpreter.
  void SetExternalContext(TfLiteExternalContextType type,
                          TfLiteExternalContext* ctx);

  /// Assigns (or reassigns) a custom memory allocation for the given tensor.
  /// `flags` is a bitmask, see TfLiteCustomAllocationFlags.
  /// The runtime does NOT take ownership of the underlying memory.
  ///
  /// NOTE: User needs to call AllocateTensors() after this.
  /// Invalid/insufficient buffers will cause an error during AllocateTensors or
  /// Invoke (in case of dynamic shapes in the graph).
  ///
  /// Parameters should satisfy the following conditions:
  /// 1. tensor->allocation_type == kTfLiteArenaRw or kTfLiteArenaRwPersistent
  ///    In general, this is true for I/O tensors & variable tensors.
  /// 2. allocation->data has the appropriate permissions for runtime access
  ///    (Read-only for inputs, Read-Write for others), and outlives
  ///    Interpreter.
  /// 3. allocation->bytes >= tensor->bytes.
  ///    This condition is checked again if any tensors are resized.
  /// 4. allocation->data should be aligned to kDefaultTensorAlignment
  ///    defined in lite/util.h. (Currently 64 bytes)
  ///    This check is skipped if kTfLiteCustomAllocationFlagsSkipAlignCheck is
  ///    set through `flags`.
  ///
  /// WARNING: This is an experimental interface that is subject to change.
  TfLiteStatus SetCustomAllocationForTensor(
      int tensor_index, const TfLiteCustomAllocation& allocation,
      int64_t flags = kTfLiteCustomAllocationFlagsNone);

  /// Apply InterpreterOptions which tunes behavior of the interpreter.
  /// WARNING: This is an experimental interface that is subject to change.
  TfLiteStatus ApplyOptions(InterpreterOptions* options);

#ifndef DOXYGEN_SKIP
  /// Return the number of subgraphs in the model.
  /// WARNING: This is an experimental API and subject to change.
  size_t subgraphs_size() const { return subgraphs_.size(); }
702 
703   /// Get a pointer to a subgraph if in bounds.
704   /// WARNING: This is an experimental API and subject to change.
subgraph(int subgraph_index)705   const Subgraph* subgraph(int subgraph_index) const {
706     if (subgraph_index < 0 ||
707         static_cast<size_t>(subgraph_index) >= subgraphs_size()) {
708       return nullptr;
709     }
710     return subgraphs_[subgraph_index].get();
711   }
712 
713   /// WARNING: This is an experimental API and subject to change.
subgraph(int subgraph_index)714   Subgraph* subgraph(int subgraph_index) {
715     return const_cast<Subgraph*>(
716         static_cast<const Interpreter*>(this)->subgraph(subgraph_index));
717   }
718 
  /// WARNING: Experimental interface, subject to change
  Subgraph& primary_subgraph() {
    return *subgraphs_.front();  // Safe as subgraphs_ always has 1 entry.
  }

  /// WARNING: Experimental interface, subject to change
  const Subgraph& primary_subgraph() const {
    return *subgraphs_.front();  // Safe as subgraphs_ always has 1 entry.
  }

  /// WARNING: Experimental interface, subject to change
  /// Get the error reporter associated with this interpreter.
  ErrorReporter* error_reporter() const { return error_reporter_; }

#endif  // DOXYGEN_SKIP
734 
 private:
  friend class InterpreterBuilder;
  friend class tflite::InterpreterTest;
  friend class tflite::delegates::InterpreterUtils;
  friend class tflite::delegates::test_utils::TestDelegation;
  friend class tflite::interpreter_wrapper::InterpreterWrapper;

  /// Set the value of an external context.
  static void SetExternalContext(struct TfLiteContext* context,
                                 TfLiteExternalContextType type,
                                 TfLiteExternalContext* ctx);
746 
747   // Helper method that return the tensor index that corresponds to
748   // a name in a SignatureDef. Defined by 'signature_key', and
749   // 'signature_tensor_name'.
750   // If 'is_input' is true then the tensor is checked in input tensors,
751   // otherwise it will be checked in output tensors.
752   // Returns -1 if the tensor is not found.
GetTensorIndexFromSignature(const char * signature_tensor_name,const char * signature_key,bool is_input)753   int GetTensorIndexFromSignature(const char* signature_tensor_name,
754                                   const char* signature_key,
755                                   bool is_input) const {
756     // Iterate directly and don't use other methods to avoid extra allocation.
757     for (const auto& signature : signature_defs_) {
758       if (signature.signature_key != signature_key) continue;
759       auto& signature_list = (is_input ? signature.inputs : signature.outputs);
760       auto tensor_iter = signature_list.find(signature_tensor_name);
761       if (tensor_iter == signature_list.end()) return -1;
762       return tensor_iter->second;
763     }
764     return -1;
765   }
766 
  // Applies TFLite default delegates.
  TfLiteStatus ApplyLazyDelegateProviders();

  // Private non-experimental implementation of ModifyGraphWithDelegate.
  // Unlike ModifyGraphWithDelegate, ModifyGraphWithDelegateImpl is defined in
  // interpreter.cc rather than in interpreter_experimental.cc, so it can be
  // used to implement other non-experimental methods.
  TfLiteStatus ModifyGraphWithDelegateImpl(TfLiteDelegate* delegate);
775 
  // Same as ModifyGraphWithDelegateImpl except that it takes ownership of the
  // delegate.
  template <typename Delegate, typename Deleter>
  inline TfLiteStatus ModifyGraphWithDelegateImpl(
      std::unique_ptr<Delegate, Deleter>&& delegate) {
    // Grab the deleter before release() detaches the raw pointer from the
    // incoming unique_ptr; the type-erased wrapper below needs it.
    Deleter deleter = std::move(delegate.get_deleter());

    // Note that we retain ownership of the delegate even if graph modification
    // fails, as delegate use will be in an indeterminate state at that point.
    owned_delegates_.emplace_back(
        delegate.release(), [deleter](TfLiteDelegate* delegate_to_delete) {
          // Cast back from the stored TfLiteDelegate* to the concrete pointer
          // type the captured deleter expects before invoking it.
          deleter(
              static_cast<typename std::unique_ptr<Delegate, Deleter>::pointer>(
                  delegate_to_delete));
        });
    return ModifyGraphWithDelegateImpl(owned_delegates_.back().get());
  }
793 
  // Overrides execution plan. This bounds checks indices sent in.
  // Note: Only used during initialization.
  TfLiteStatus SetExecutionPlan(const std::vector<int>& new_plan);

  // Sets the profiler to all subgraphs.
  void SetSubgraphProfiler();

  // Remove delegates (for fallback behaviour). The interpreter is invokable
  // afterwards.
  TfLiteStatus RemoveAllDelegates();

  // Returns true if delegates have been applied.
  bool HasDelegates();

  // Returns true if the model has been fully delegated.
  bool IsFullyDelegated() const;

  // Returns true if cancellation function returns true.
  bool IsCancelled();
814   // Sets the list of signature defs in the model.
SetSignatureDef(std::vector<internal::SignatureDef> signature_defs)815   void SetSignatureDef(std::vector<internal::SignatureDef> signature_defs) {
816     signature_defs_ = std::move(signature_defs);
817   }
818 
  // Sets model metadata as a mapping of name (key) and buffer (value) strings.
  // Used by InterpreterBuilder, should be called after setting up subgraphs.
  TfLiteStatus SetMetadata(const std::map<std::string, std::string>& metadata);

  /// Adds `subgraphs_to_add` subgraphs, preserving pre-existing Subgraph
  /// entries. The value pointed to by `first_new_subgraph_index` will be set to
  /// the index of the first new subgraph if `first_new_subgraph_index` is
  /// non-null.
  void AddSubgraphs(int subgraphs_to_add,
                    int* first_new_subgraph_index = nullptr);

  /// Implementation of SetProfiler.
  /// Unlike SetProfiler, this is defined in interpreter.cc rather than in
  /// interpreter_experimental.cc, so it can be used by interpreter_builder.cc.
  void SetProfilerImpl(std::unique_ptr<Profiler> profiler);

  TfLiteStatus ApplyOptionsImpl(InterpreterOptions* options);
836 
  // A pure C data structure used to communicate with the pure C plugin
  // interface. To avoid copying tensor metadata, this is also the definitive
  // structure to store tensors.
  // This is the primary subgraph context.
  TfLiteContext* context_ = nullptr;

  // The error reporter delegate that tflite will forward error queries to.
  ErrorReporter* error_reporter_ = nullptr;

  // List of delegates that have been installed and are owned by this
  // interpreter instance. Useful if client delegate ownership is burdensome.
  // WARNING: This is an experimental API and subject to change.
  // TODO(b/116667551): Use TfLiteExternalContext for storing state.
  std::vector<
      std::unique_ptr<TfLiteDelegate, std::function<void(TfLiteDelegate*)>>>
      owned_delegates_;

  // A root profiler that holds a list of attached profiler implementations.
  // Will be nullptr if there's no child profiler registered.
  std::unique_ptr<profiling::RootProfiler> root_profiler_;

  // Whether output-tensor data may be left in buffer handles; see
  // SetAllowBufferHandleOutput().
  bool allow_buffer_handle_output_ = false;

  // List of active external contexts.
  TfLiteExternalContext* external_contexts_[kTfLiteMaxExternalContexts];

  // The default external cpu backend context. After a TFLite interpreter is
  // initialized, 'external_contexts_[kTfLiteCpuBackendContext]' is set to point
  // to this object. However, if this element value is overwritten via calling
  // 'SetExternalContext(kTfLiteCpuBackendContext, ...)', we will reset this to
  // nullptr if necessary.
  std::unique_ptr<ExternalCpuBackendContext> own_external_cpu_backend_context_;

  // Subgraphs. The first entry is the primary subgraph (see
  // primary_subgraph(), which dereferences subgraphs_.front()).
  std::vector<std::unique_ptr<Subgraph>> subgraphs_;

  // A map of resources. Owned by interpreter and shared by multiple subgraphs.
  resource::ResourceMap resources_;

  // A map of resource Ids. Owned by interpreter and shared by multiple
  // subgraphs.
  resource::ResourceIDMap resource_ids_;

  // A map of initialization statuses, that indicate whether the initialization
  // subgraph invocation is done or not. Owned by interpreter and shared by
  // multiple subgraphs.
  resource::InitializationStatusMap initialization_status_map_;

  // Indicating delegates that the TFLite interpreter will apply by default.
  // An empty one means there's no delegate to be applied by default or
  // delegates have been applied and don't need to be applied again.
  using TfLiteDelegateCreator =
      std::function<TfLiteDelegatePtr(int /*num_threads*/)>;
  using TfLiteDelegateCreators = std::vector<TfLiteDelegateCreator>;
  TfLiteDelegateCreators lazy_delegate_providers_;

  // List of SignatureDefs obtained from the model.
  std::vector<internal::SignatureDef> signature_defs_;

  // Map of signature key to its corresponding SignatureRunner object.
  // A SignatureRunner is basically a wrapper of the Subgraph corresponding to
  // its SignatureDef.
  std::map<std::string, SignatureRunner> signature_runner_map_;

  // Model metadata stored as mapping of name (key) to buffer (value).
  // Data is mapped from the Metadata in TFLite flatbuffer model.
  std::map<std::string, std::string> metadata_;

  // InterpreterOptions object which is being used.
  std::unique_ptr<InterpreterOptions> options_;
907 };
908 
909 }  // namespace tflite
910 #endif  // TENSORFLOW_LITE_INTERPRETER_H_
911