/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
/// \file
/// Main abstraction controlling the tflite interpreter.
/// See c/common.h for the API for defining operations (TfLiteRegistration).
#ifndef TENSORFLOW_LITE_INTERPRETER_H_
#define TENSORFLOW_LITE_INTERPRETER_H_

#include <stddef.h>
#include <stdint.h>

#include <complex>
#include <cstdio>
#include <cstdlib>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "tensorflow/lite/allocation.h"
#include "tensorflow/lite/c/common.h"  // IWYU pragma: export
#include "tensorflow/lite/core/api/error_reporter.h"
#include "tensorflow/lite/core/api/profiler.h"
#include "tensorflow/lite/core/subgraph.h"
#include "tensorflow/lite/experimental/resource/initialization_status.h"
#include "tensorflow/lite/experimental/resource/resource_base.h"
#include "tensorflow/lite/external_cpu_backend_context.h"
#include "tensorflow/lite/internal/signature_def.h"
#include "tensorflow/lite/interpreter_options.h"
#include "tensorflow/lite/portable_type_to_tflitetype.h"
#include "tensorflow/lite/profiling/root_profiler.h"
#include "tensorflow/lite/signature_runner.h"
#include "tensorflow/lite/stderr_reporter.h"
#include "tensorflow/lite/string_type.h"
#include "tensorflow/lite/type_to_tflitetype.h"

namespace tflite {

class InterpreterTest;  // Class for friend declarations.

namespace delegates {
class InterpreterUtils;  // Class for friend declarations.

namespace test_utils {
class TestDelegation;  // Class for friend declarations.
}  // namespace test_utils
}  // namespace delegates

namespace interpreter_wrapper {
class InterpreterWrapper;  // Class for friend declarations.
}  // namespace interpreter_wrapper

/// An interpreter for a graph of nodes that consume and produce tensors.
/// Each node of the graph processes a set of input tensors and produces a
/// set of output tensors. All input/output tensors are referenced by index.
///
/// Usage:
///
/// <pre><code>
/// // Create model from file. Note that the model instance must outlive the
/// // interpreter instance.
/// auto model = tflite::FlatBufferModel::BuildFromFile(...);
/// if (model == nullptr) {
///   // Return error.
/// }
/// // Create an Interpreter with an InterpreterBuilder.
/// std::unique_ptr<tflite::Interpreter> interpreter;
/// tflite::ops::builtin::BuiltinOpResolver resolver;
/// if (InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
///   // Return failure.
/// }
/// if (interpreter->AllocateTensors() != kTfLiteOk) {
///   // Return failure.
/// }
///
/// auto input = interpreter->typed_tensor<float>(0);
/// for (int i = 0; i < input_size; i++) {
///   input[i] = ...;
/// }
/// interpreter->Invoke();
/// </code></pre>
///
/// Note: For nearly all practical use cases, one should not directly construct
/// an Interpreter object, but rather use the InterpreterBuilder.
///
/// WARNING: This class is *not* thread-safe. The client is responsible for
/// ensuring serialized interaction to avoid data races and undefined behavior.
class Interpreter {
 public:
  // Instantiate an interpreter. All errors associated with reading and
  // processing this model will be forwarded to the error_reporter object.
  //
  // Note: if error_reporter is nullptr, then a default StderrReporter is
  // used. Ownership of 'error_reporter' remains with the caller.
  // WARNING: Use of this constructor outside of an InterpreterBuilder is not
  // recommended.
  explicit Interpreter(ErrorReporter* error_reporter = DefaultErrorReporter());

  ~Interpreter();

  // Interpreters are not copyable as they have non-trivial memory semantics.
  Interpreter(const Interpreter&) = delete;
  Interpreter& operator=(const Interpreter&) = delete;

  // Functions to build the interpreter.
#ifndef DOXYGEN_SKIP
  /// Provide a list of tensor indexes that are inputs to the model.
  /// Each index is bounds-checked, and this modifies the consistent_ flag of
  /// the interpreter.
  TfLiteStatus SetInputs(std::vector<int> inputs);

  /// Provide a list of tensor indexes that are outputs of the model.
  /// Each index is bounds-checked, and this modifies the consistent_ flag of
  /// the interpreter.
  TfLiteStatus SetOutputs(std::vector<int> outputs);

  /// Provide a list of tensor indexes that are variable tensors.
  /// Each index is bounds-checked, and this modifies the consistent_ flag of
  /// the interpreter.
  TfLiteStatus SetVariables(std::vector<int> variables);

  /// Adds a node with the given parameters and returns the index of the new
  /// node in `node_index` (optionally). The Interpreter takes ownership of
  /// `builtin_data` and destroys it with `free`. Ownership of 'init_data'
  /// remains with the caller.
  TfLiteStatus AddNodeWithParameters(const std::vector<int>& inputs,
                                     const std::vector<int>& outputs,
                                     const char* init_data,
                                     size_t init_data_size, void* builtin_data,
                                     const TfLiteRegistration* registration,
                                     int* node_index = nullptr);

  /// Adds `tensors_to_add` tensors, preserving pre-existing Tensor entries.
  /// The value pointed to by `first_new_tensor_index` will be set to the
  /// index of the first new tensor if `first_new_tensor_index` is non-null.
  TfLiteStatus AddTensors(int tensors_to_add,
                          int* first_new_tensor_index = nullptr);

  /// Set the parameters (type, name, dims, quantization) for tensor
  /// `tensor_index`. This variant assumes an external, read-only buffer of
  /// size `bytes` has already been allocated; the buffer must outlive the
  /// Interpreter.
  TfLiteStatus SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantization quantization,
      const char* buffer, size_t bytes, const Allocation* allocation = nullptr);
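
  // Example (an illustrative sketch only, not a supported recipe): how the
  // low-level construction API above can assemble a trivial one-op graph.
  // `resolver` is assumed to be a tflite::ops::builtin::BuiltinOpResolver;
  // registration lookup details may differ across TFLite versions.
  //
  //   Interpreter interp;
  //   interp.AddTensors(3);                        // tensors 0,1 in; 2 out
  //   for (int t = 0; t < 3; ++t) {
  //     interp.SetTensorParametersReadWrite(t, kTfLiteFloat32, "",
  //                                         /*dims=*/{2},
  //                                         TfLiteQuantization());
  //   }
  //   interp.SetInputs({0, 1});
  //   interp.SetOutputs({2});
  //   const TfLiteRegistration* reg =
  //       resolver.FindOp(::tflite::BuiltinOperator_ADD, /*version=*/1);
  //   // builtin_data must be heap-allocated: the interpreter frees it with
  //   // free().
  //   auto* params =
  //       static_cast<TfLiteAddParams*>(calloc(1, sizeof(TfLiteAddParams)));
  //   params->activation = kTfLiteActNone;
  //   interp.AddNodeWithParameters({0, 1}, {2}, /*init_data=*/nullptr,
  //                                /*init_data_size=*/0, params, reg);
  //   interp.AllocateTensors();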
  /// Legacy. Deprecated in favor of the variant above.
  inline TfLiteStatus SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantizationParams quantization,
      const char* buffer, size_t bytes,
      const Allocation* allocation = nullptr) {
    return SetTensorParametersReadOnly(tensor_index, type, name, dims.size(),
                                       dims.data(), quantization, buffer, bytes,
                                       allocation);
  }

  TfLiteStatus SetTensorParametersReadOnly(
      int tensor_index, TfLiteType type, const char* name, const size_t rank,
      const int* dims, TfLiteQuantizationParams quantization,
      const char* buffer, size_t bytes, const Allocation* allocation = nullptr);

  /// Set the parameters (type, name, dims, quantization) for tensor
  /// `tensor_index`. This variant allocates a read/write tensor whose buffer
  /// is managed by the Interpreter.
  TfLiteStatus SetTensorParametersReadWrite(int tensor_index, TfLiteType type,
                                            const char* name,
                                            const std::vector<int>& dims,
                                            TfLiteQuantization quantization,
                                            bool is_variable = false);

  /// Legacy. Deprecated in favor of the variant above.
  inline TfLiteStatus SetTensorParametersReadWrite(
      int tensor_index, TfLiteType type, const char* name,
      const std::vector<int>& dims, TfLiteQuantizationParams quantization,
      bool is_variable = false,
      const std::vector<int>* dims_signature = nullptr) {
    size_t rank_dims_signature = 0;
    const int* dims_signature_pointer = nullptr;
    if (dims_signature) {
      rank_dims_signature = dims_signature->size();
      dims_signature_pointer = dims_signature->data();
    }
    return SetTensorParametersReadWrite(
        tensor_index, type, name, dims.size(), dims.data(), quantization,
        is_variable, rank_dims_signature, dims_signature_pointer);
  }
  TfLiteStatus SetTensorParametersReadWrite(
      int tensor_index, TfLiteType type, const char* name, const size_t rank,
      const int* dims, TfLiteQuantizationParams quantization,
      bool is_variable = false, const size_t rank_dims_signature = 0,
      const int* dims_signature = nullptr);
#endif  // DOXYGEN_SKIP
  // Functions to access tensor data.

  /// Read-only access to the list of input tensor indices.
  const std::vector<int>& inputs() const { return primary_subgraph().inputs(); }

  /// Return the name of a given input. The given index must be between 0 and
  /// inputs().size().
  const char* GetInputName(int index) const {
    return context_->tensors[inputs()[index]].name;
  }

  /// Read-only access to the list of output tensor indices.
  const std::vector<int>& outputs() const {
    return primary_subgraph().outputs();
  }

  /// Read-only access to the list of variable tensor indices.
  const std::vector<int>& variables() const {
    return primary_subgraph().variables();
  }

  /// Return the name of a given output. The given index must be between 0 and
  /// outputs().size().
  const char* GetOutputName(int index) const {
    return context_->tensors[outputs()[index]].name;
  }

  /// Return the number of tensors in the model.
  size_t tensors_size() const { return context_->tensors_size; }
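
  // Example (sketch): enumerating model inputs and outputs by name with the
  // accessors above, assuming `interpreter` points at a built Interpreter.
  //
  //   for (size_t i = 0; i < interpreter->inputs().size(); ++i) {
  //     printf("input %zu: %s\n", i,
  //            interpreter->GetInputName(static_cast<int>(i)));
  //   }
  //   for (size_t i = 0; i < interpreter->outputs().size(); ++i) {
  //     printf("output %zu: %s\n", i,
  //            interpreter->GetOutputName(static_cast<int>(i)));
  //   }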
  /// Return the number of ops in the model.
  size_t nodes_size() const { return primary_subgraph().nodes_size(); }

  /// WARNING: Experimental interface, subject to change.
  const std::vector<int>& execution_plan() const {
    return primary_subgraph().execution_plan();
  }

  /// Get a mutable tensor data structure.
  // TODO(aselle): Create a safe ArrayHandle interface to avoid exposing this
  // read/write access to structure.
  TfLiteTensor* tensor(int tensor_index) {
    return primary_subgraph().tensor(tensor_index);
  }

  /// Get an immutable tensor data structure.
  const TfLiteTensor* tensor(int tensor_index) const {
    return primary_subgraph().tensor(tensor_index);
  }

  /// Returns a pointer to an operation and registration data structure if in
  /// bounds of the primary subgraph (subgraphs_[0]).
  const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration(
      int node_index) const {
    return primary_subgraph().node_and_registration(node_index);
  }

  /// Returns a pointer to an operation and registration data structure if in
  /// bounds.
  const std::pair<TfLiteNode, TfLiteRegistration>* node_and_registration(
      int subgraph_index, int node_index) const {
    return subgraph(subgraph_index)->node_and_registration(node_index);
  }

  /// Perform a checked cast to the appropriate tensor type (mutable pointer
  /// version).
  template <class T>
  T* typed_tensor(int tensor_index) {
    if (TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
      if (tensor_ptr->type == typeToTfLiteType<T>()) {
        return reinterpret_cast<T*>(tensor_ptr->data.raw);
      }
    }
    return nullptr;
  }

  /// Perform a checked cast to the appropriate tensor type (immutable pointer
  /// version).
  template <class T>
  const T* typed_tensor(int tensor_index) const {
    if (const TfLiteTensor* tensor_ptr = tensor(tensor_index)) {
      if (tensor_ptr->type == typeToTfLiteType<T>()) {
        return reinterpret_cast<const T*>(tensor_ptr->data.raw);
      }
    }
    return nullptr;
  }

  /// WARNING: Experimental interface, subject to change.
  /// Returns the keys of all the method signatures defined in the model.
  /// Note: the returned pointers have the same lifetime as the Interpreter
  /// object.
  std::vector<const std::string*> signature_keys() const {
    std::vector<const std::string*> signature_keys;
    signature_keys.reserve(signature_defs_.size());
    for (const auto& sig_def : signature_defs_) {
      signature_keys.emplace_back(&sig_def.signature_key);
    }
    return signature_keys;
  }

  /// WARNING: Experimental interface, subject to change.
  /// Returns a pointer to the SignatureRunner instance for running the part of
  /// the graph identified by a SignatureDef. Returns nullptr if the given
  /// signature key is not valid.
  /// If you need to specify delegates, you have to do that before calling this
  /// function, as this function will additionally apply the default delegates;
  /// applying delegates afterwards may lead to undesirable behavior.
  /// Note: the returned instance has the same lifetime as the Interpreter
  /// object, and the SignatureRunner class is *not* thread-safe.
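  ///
  /// Example (a sketch; "serving_default" and the tensor names "x"/"y" are
  /// placeholders for whatever the model's SignatureDef actually defines):
  /// <pre><code>
  /// tflite::SignatureRunner* runner =
  ///     interpreter->GetSignatureRunner("serving_default");
  /// if (runner == nullptr) { /* No such signature key. */ }
  /// runner->ResizeInputTensor("x", {1, 224, 224, 3});
  /// runner->AllocateTensors();
  /// float* x = runner->input_tensor("x")->data.f;
  /// // ... fill x ...
  /// runner->Invoke();
  /// const float* y = runner->output_tensor("y")->data.f;
  /// </code></pre>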
  SignatureRunner* GetSignatureRunner(const char* signature_key);

  /// WARNING: Experimental interface, subject to change.
  /// Returns the subgraph index that corresponds to the SignatureDef defined
  /// by 'signature_key'. Returns -1 if an invalid key is passed.
  int GetSubgraphIndexFromSignature(const char* signature_key) const {
    for (const auto& signature : signature_defs_) {
      if (signature.signature_key == signature_key) {
        return signature.subgraph_index;
      }
    }
    return -1;
  }

  /// WARNING: Experimental interface, subject to change.
  /// Returns the mapping of input names to tensor indices in the signature
  /// specified through 'signature_key'. Returns an empty map if an invalid
  /// key is passed.
  const std::map<std::string, uint32_t>& signature_inputs(
      const char* signature_key) const {
    for (const auto& sig_def : signature_defs_) {
      if (sig_def.signature_key == signature_key) return sig_def.inputs;
    }
    static const std::map<std::string, uint32_t>* default_empty_list =
        new std::map<std::string, uint32_t>();
    return *default_empty_list;
  }

  /// WARNING: Experimental interface, subject to change.
  /// Returns the mapping of output names to tensor indices in the signature
  /// specified through 'signature_key'. Returns an empty map if an invalid
  /// key is passed.
  const std::map<std::string, uint32_t>& signature_outputs(
      const char* signature_key) const {
    for (const auto& sig_def : signature_defs_) {
      if (sig_def.signature_key == signature_key) return sig_def.outputs;
    }
    static const std::map<std::string, uint32_t>* default_empty_list =
        new std::map<std::string, uint32_t>();
    return *default_empty_list;
  }

  /// WARNING: Experimental interface, subject to change.
  /// Returns the input tensor identified by 'signature_input_name' in the
  /// signature identified by 'signature_key'. Returns nullptr if not found.
  TfLiteTensor* input_tensor_by_signature(const char* signature_input_name,
                                          const char* signature_key) {
    const int subgraph_index = GetSubgraphIndexFromSignature(signature_key);
    if (subgraph_index == -1) return nullptr;
    const int tensor_index = GetTensorIndexFromSignature(
        signature_input_name, signature_key, /*is_input=*/true);
    if (tensor_index == -1) return nullptr;
    return subgraph(subgraph_index)->tensor(tensor_index);
  }

  /// WARNING: Experimental interface, subject to change.
  /// Returns the output tensor identified by 'signature_output_name' in the
  /// signature identified by 'signature_key'. Returns nullptr if not found.
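  ///
  /// Example (a sketch; the signature key and output name are hypothetical):
  /// <pre><code>
  /// const TfLiteTensor* logits =
  ///     interpreter->output_tensor_by_signature("logits", "serving_default");
  /// if (logits == nullptr) { /* Signature or output name not found. */ }
  /// </code></pre>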
  const TfLiteTensor* output_tensor_by_signature(
      const char* signature_output_name, const char* signature_key) const {
    const int subgraph_index = GetSubgraphIndexFromSignature(signature_key);
    if (subgraph_index == -1) return nullptr;
    const int tensor_index = GetTensorIndexFromSignature(
        signature_output_name, signature_key, /*is_input=*/false);
    if (tensor_index == -1) return nullptr;
    return subgraph(subgraph_index)->tensor(tensor_index);
  }

  /// Return a mutable pointer to the given input tensor. The given index must
  /// be between 0 and inputs().size().
  TfLiteTensor* input_tensor(size_t index) { return tensor(inputs()[index]); }

  /// Return an immutable pointer to the given input tensor. The given index
  /// must be between 0 and inputs().size().
  const TfLiteTensor* input_tensor(size_t index) const {
    return tensor(inputs()[index]);
  }

  /// Return a mutable pointer into the data of a given input tensor. The given
  /// index must be between 0 and inputs().size().
  template <class T>
  T* typed_input_tensor(int index) {
    return typed_tensor<T>(inputs()[index]);
  }

  /// Return an immutable pointer into the data of a given input tensor. The
  /// given index must be between 0 and inputs().size().
  template <class T>
  const T* typed_input_tensor(int index) const {
    return typed_tensor<T>(inputs()[index]);
  }

  /// Return a mutable pointer to the given output tensor. The given index must
  /// be between 0 and outputs().size().
  TfLiteTensor* output_tensor(size_t index) { return tensor(outputs()[index]); }

  /// Return an immutable pointer to the given output tensor. The given index
  /// must be between 0 and outputs().size().
  const TfLiteTensor* output_tensor(size_t index) const {
    return tensor(outputs()[index]);
  }

  /// Return a mutable pointer into the data of a given output tensor. The
  /// given index must be between 0 and outputs().size().
  template <class T>
  T* typed_output_tensor(int index) {
    return typed_tensor<T>(outputs()[index]);
  }

  /// Return an immutable pointer into the data of a given output tensor. The
  /// given index must be between 0 and outputs().size().
  template <class T>
  const T* typed_output_tensor(int index) const {
    return typed_tensor<T>(outputs()[index]);
  }

  /// Change the dimensionality of a given tensor. Note that this is only
  /// acceptable for tensor indices that are inputs or variables.
  /// Returns the status of failure or success. Note that this doesn't actually
  /// resize any existing buffers; a call to AllocateTensors() is required to
  /// change the tensor input buffer.
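  ///
  /// Example (a sketch; the shape {4, 224, 224, 3} is a hypothetical batch of
  /// four images): resizing input 0 and re-planning allocations before the
  /// next Invoke().
  /// <pre><code>
  /// if (interpreter->ResizeInputTensor(interpreter->inputs()[0],
  ///                                    {4, 224, 224, 3}) != kTfLiteOk) {
  ///   // Return failure.
  /// }
  /// if (interpreter->AllocateTensors() != kTfLiteOk) {
  ///   // Return failure.
  /// }
  /// </code></pre>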
  TfLiteStatus ResizeInputTensor(int tensor_index,
                                 const std::vector<int>& dims);

  /// Change the dimensionality of a given tensor. This is only acceptable for
  /// tensor indices that are inputs or variables. Only unknown dimensions can
  /// be resized with this function. Unknown dimensions are indicated as `-1`
  /// in the `dims_signature` attribute of a `TfLiteTensor`. Returns the status
  /// of failure or success. Note that this doesn't actually resize any
  /// existing buffers; a call to AllocateTensors() is required to change the
  /// tensor input buffer.
  TfLiteStatus ResizeInputTensorStrict(int tensor_index,
                                       const std::vector<int>& dims);

  /// This releases memory held by non-persistent tensors. It does NOT
  /// re-perform memory planning. AllocateTensors needs to be called before
  /// the next invocation.
  /// WARNING: Experimental interface, subject to change.
  TfLiteStatus ReleaseNonPersistentMemory();

  /// Update allocations for all tensors. This will redim dependent tensors
  /// using the input tensor dimensionality as given. This is relatively
  /// expensive. This *must be* called after the interpreter has been created
  /// and before running inference (and accessing tensor buffers), and *must
  /// be* called again if (and only if) an input tensor is resized. Returns the
  /// status of success or failure. Will fail if any of the ops in the model
  /// (other than those which were rewritten by delegates, if any) are not
  /// supported by the Interpreter's OpResolver.
  TfLiteStatus AllocateTensors();

  /// Invoke the interpreter (run the whole graph in dependency order).
  ///
  /// NOTE: It is possible that the interpreter is not in a ready state to
  /// evaluate (e.g. if ResizeInputTensor() has been performed without a
  /// subsequent call to AllocateTensors()).
  /// Returns the status of success or failure.
  TfLiteStatus Invoke();

  /// Set the number of threads available to the interpreter.
  ///
  /// NOTE: `num_threads` should be >= -1. Setting `num_threads` to 0 has the
  /// effect of disabling multithreading, which is equivalent to setting
  /// `num_threads` to 1. If set to the value -1, the number of threads used
  /// will be implementation-defined and platform-dependent.
  ///
  /// As the TfLite interpreter may internally apply a TfLite delegate by
  /// default (e.g. XNNPACK), the number of threads that are available to the
  /// default delegate *should be* set via InterpreterBuilder APIs as follows:
  ///
  /// <pre><code>
  /// std::unique_ptr<tflite::Interpreter> interpreter;
  /// tflite::InterpreterBuilder builder(tflite model, op resolver);
  /// builder.SetNumThreads(...)
  /// ASSERT_EQ(builder(&interpreter), kTfLiteOk);
  /// </code></pre>
  ///
  /// WARNING: This API is deprecated: prefer using
  /// `InterpreterBuilder::SetNumThreads`, as documented above.
  TfLiteStatus SetNumThreads(int num_threads);

  /// Allow float16 precision for FP32 calculation when possible.
  /// Default: not allowed.
  ///
  /// WARNING: This API is deprecated: prefer controlling this via delegate
  /// options, e.g. `tflite::StatefulNnApiDelegate::Options::allow_fp16` or
  /// `TfLiteGpuDelegateOptionsV2::is_precision_loss_allowed`.
  /// This method will be removed in a future release.
  void SetAllowFp16PrecisionForFp32(bool allow);

  /// Get the half precision flag.
  /// WARNING: This is an experimental API and subject to change.
  bool GetAllowFp16PrecisionForFp32() const {
    return context_->allow_fp32_relax_to_fp16;
  }

  /// Sets the cancellation function pointer in order to cancel a request in
  /// the middle of a call to Invoke(). The interpreter queries this function
  /// during inference, between op invocations; when it returns true, the
  /// interpreter will abort execution and return `kTfLiteError`. The `data`
  /// parameter contains any data used by the cancellation function, and if
  /// non-null, remains owned by the caller.
  /// WARNING: This is an experimental API and subject to change.
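  ///
  /// Example (a minimal sketch, assuming the caller owns a std::atomic<bool>
  /// that outlives the interpreter):
  /// <pre><code>
  /// std::atomic<bool> cancelled(false);
  /// interpreter->SetCancellationFunction(&cancelled, [](void* data) {
  ///   return static_cast<std::atomic<bool>*>(data)->load();
  /// });
  /// // From another thread: `cancelled = true;` makes the in-flight Invoke()
  /// // abort with kTfLiteError.
  /// </code></pre>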
  void SetCancellationFunction(void* data, bool (*check_cancelled_func)(void*));

  /// Allow a delegate to look at the graph and modify the graph to handle
  /// parts of the graph itself. After this is called, the graph may
  /// contain new nodes that replace one or more nodes.
  /// 'delegate' must outlive the interpreter.
  /// Returns one of the following status codes:
  /// 1. kTfLiteOk: Success.
  /// 2. kTfLiteDelegateError: Delegation failed due to an error in the
  ///    delegate, or the delegate parameter was null. The Interpreter has been
  ///    restored to its pre-delegation state.
  ///    NOTE: This undoes all delegates previously applied to the Interpreter.
  /// 3. kTfLiteApplicationError: Delegation failed to be applied due to
  ///    incompatibility with the TfLite runtime, e.g., the model graph is
  ///    already immutable when applying the delegate. However, the interpreter
  ///    could still be invoked.
  /// 4. kTfLiteUnresolvedOps: Delegation failed because the model has an
  ///    operator that cannot be resolved. This can happen when the op is not
  ///    registered or built with the TF Lite framework.
  /// 5. kTfLiteError: Unexpected/runtime failure.
  /// WARNING: This is an experimental API and subject to change.
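  ///
  /// Example (a sketch, assuming the XNNPACK delegate is linked into the
  /// binary; other delegates follow the same pattern):
  /// <pre><code>
  /// TfLiteDelegate* delegate =
  ///     TfLiteXNNPackDelegateCreate(/*options=*/nullptr);
  /// if (interpreter->ModifyGraphWithDelegate(delegate) != kTfLiteOk) {
  ///   // Delegation failed; the interpreter was restored to its prior state.
  /// }
  /// // ... run inference ...
  /// interpreter.reset();  // The delegate must outlive the interpreter.
  /// TfLiteXNNPackDelegateDelete(delegate);
  /// </code></pre>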
  TfLiteStatus ModifyGraphWithDelegate(TfLiteDelegate* delegate);

  // Owning handle to a TfLiteDelegate instance.
  using TfLiteDelegatePtr =
      std::unique_ptr<TfLiteDelegate, void (*)(TfLiteDelegate*)>;

  /// Same as ModifyGraphWithDelegate except this interpreter takes
  /// ownership of the provided delegate.
  /// WARNING: This is an experimental API and subject to change.
  template <typename Delegate, typename Deleter>
  inline TfLiteStatus ModifyGraphWithDelegate(
      std::unique_ptr<Delegate, Deleter> delegate) {
    Deleter deleter = std::move(delegate.get_deleter());

    // Note that we retain ownership of the delegate even if graph modification
    // fails, as delegate use will be in an indeterminate state at that point.
    owned_delegates_.emplace_back(
        delegate.release(), [deleter](TfLiteDelegate* delegate_to_delete) {
          deleter(
              static_cast<typename std::unique_ptr<Delegate,
                                                   Deleter>::pointer>(
                  delegate_to_delete));
        });
    return ModifyGraphWithDelegate(owned_delegates_.back().get());
  }

  /// This overload is *never* OK. TfLiteDelegate is a C structure, so it has
  /// no virtual destructor. The default deleter of the unique_ptr does not
  /// know how to delete C++ objects deriving from TfLiteDelegate.
  TfLiteStatus ModifyGraphWithDelegate(
      std::unique_ptr<TfLiteDelegate> delegate) = delete;

  /// Ensure the data in `tensor.data` is readable. If a delegate is used,
  /// this may require copying the data from the delegate buffer to raw memory.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus EnsureTensorDataIsReadable(int tensor_index) {
    return primary_subgraph().EnsureTensorDataIsReadable(tensor_index);
  }

  /// Set the delegate buffer handle for a tensor. It can be called in the
  /// following cases:
  /// 1. Set the buffer handle of a tensor that is not being written by a
  ///    delegate. For example, feeding an OpenGL texture as the input of the
  ///    inference graph.
  /// 2. Set the buffer handle of a tensor that uses the same delegate.
  ///    For example, setting an OpenGL texture as the output of inference,
  ///    while the node which produces the output is an OpenGL delegate node.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus SetBufferHandle(int tensor_index,
                               TfLiteBufferHandle buffer_handle,
                               TfLiteDelegate* delegate);

  /// Get the delegate buffer handle, and the delegate which can process the
  /// buffer handle.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus GetBufferHandle(int tensor_index,
                               TfLiteBufferHandle* buffer_handle,
                               TfLiteDelegate** delegate);

  /// Sets the profiler for tracing execution. The caller retains ownership
  /// of the profiler and must ensure its validity.
  /// Previously registered profilers will be unregistered.
  /// If `profiler` is nullptr, all previously installed profilers will be
  /// removed.
  /// WARNING: This is an experimental API and subject to change.
  void SetProfiler(Profiler* profiler);

  /// Same as SetProfiler except this interpreter takes ownership
  /// of the provided profiler.
  /// Previously registered profilers will be unregistered.
  /// If `profiler` is nullptr, all previously installed profilers will be
  /// removed.
  /// WARNING: This is an experimental API and subject to change.
  void SetProfiler(std::unique_ptr<Profiler> profiler);

  /// Adds a profiler for tracing execution. The caller retains ownership
  /// of the profiler and must ensure its validity.
  /// A nullptr `profiler` will be ignored.
  /// WARNING: This is an experimental API and subject to change.
  void AddProfiler(Profiler* profiler);

  /// Gets the profiler used for op tracing.
  /// WARNING: This is an experimental API and subject to change.
  Profiler* GetProfiler();
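
  // Example (an illustrative sketch only): attaching a caller-owned profiler.
  // BufferedProfiler from tensorflow/lite/profiling/profiler.h is one
  // concrete Profiler implementation; its exact constructor and accessors may
  // differ between TFLite versions.
  //
  //   tflite::profiling::BufferedProfiler profiler(/*max_num_entries=*/1024);
  //   interpreter->SetProfiler(&profiler);
  //   interpreter->Invoke();
  //   // Inspect the recorded events, e.g. via profiler.GetProfileEvents().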
  // The default capacity of the `tensors_` vector.
  static constexpr int kTensorsReservedCapacity = 128;
  /// The capacity headroom of the `tensors_` vector before calling ops'
  /// `prepare` and `invoke` functions. In these functions, it's guaranteed
  /// that allocating up to `kTensorsCapacityHeadroom` more tensors won't
  /// invalidate pointers to existing tensors.
  static constexpr int kTensorsCapacityHeadroom = 16;

  /// Set if buffer handle output is allowed.
  ///
  /// When using hardware delegation, the Interpreter will make the data of
  /// output tensors available in `tensor->data` by default. If the application
  /// can consume the buffer handle directly (e.g. reading output from an
  /// OpenGL texture), it can set this flag to false, so the Interpreter won't
  /// copy the data from the buffer handle to CPU memory.
  /// WARNING: This is an experimental API and subject to change.
  void SetAllowBufferHandleOutput(bool allow_buffer_handle_output) {
    allow_buffer_handle_output_ = allow_buffer_handle_output;
  }

  /// Reset all variable tensors to the default value.
  /// If a variable tensor doesn't have a buffer, reset it to zero.
  /// TODO(b/115961645): Implement - If a variable tensor has a buffer, reset
  /// it to the value of the buffer.
  /// WARNING: This is an experimental API and subject to change.
  TfLiteStatus ResetVariableTensors();

  /// Retrieve an operator's description of its work, for profiling purposes.
  const char* OpProfilingString(const TfLiteRegistration& op_reg,
                                const TfLiteNode* node) const {
    if (op_reg.profiling_string == nullptr) return nullptr;
    return op_reg.profiling_string(context_, node);
  }

  // Set the value of an external context. The TFLite interpreter doesn't take
  // ownership of this external context 'ctx', and the context should outlive
  // the TFLite interpreter.
  void SetExternalContext(TfLiteExternalContextType type,
                          TfLiteExternalContext* ctx);

  /// Assigns (or reassigns) a custom memory allocation for the given tensor.
  /// `flags` is a bitmask; see TfLiteCustomAllocationFlags.
  /// The runtime does NOT take ownership of the underlying memory.
  ///
  /// NOTE: The user needs to call AllocateTensors() after this.
  /// Invalid/insufficient buffers will cause an error during AllocateTensors
  /// or Invoke (in case of dynamic shapes in the graph).
  ///
  /// Parameters should satisfy the following conditions:
  /// 1. tensor->allocation_type == kTfLiteArenaRw or kTfLiteArenaRwPersistent.
  ///    In general, this is true for I/O tensors & variable tensors.
  /// 2. allocation->data has the appropriate permissions for runtime access
  ///    (read-only for inputs, read-write for others), and outlives the
  ///    Interpreter.
  /// 3. allocation->bytes >= tensor->bytes.
  ///    This condition is checked again if any tensors are resized.
  /// 4. allocation->data should be aligned to kDefaultTensorAlignment
  ///    defined in lite/util.h (currently 64 bytes).
  ///    This check is skipped if kTfLiteCustomAllocationFlagsSkipAlignCheck is
  ///    set through `flags`.
  ///
  /// WARNING: This is an experimental interface that is subject to change.
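  ///
  /// Example (a minimal sketch; assumes the tensor's byte size is already
  /// known, e.g. after a prior AllocateTensors(), and that aligned_alloc is
  /// available — note its size argument must be a multiple of the alignment):
  /// <pre><code>
  /// const int input_index = interpreter->inputs()[0];
  /// const TfLiteTensor* t = interpreter->tensor(input_index);
  /// void* data = aligned_alloc(/*alignment=*/64, /*size=*/t->bytes);
  /// TfLiteCustomAllocation allocation = {data, t->bytes};
  /// interpreter->SetCustomAllocationForTensor(input_index, allocation);
  /// interpreter->AllocateTensors();  // Required after (re)assignment.
  /// </code></pre>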
  TfLiteStatus SetCustomAllocationForTensor(
      int tensor_index, const TfLiteCustomAllocation& allocation,
      int64_t flags = kTfLiteCustomAllocationFlagsNone);

  /// Apply InterpreterOptions, which tune the behavior of the interpreter.
  /// WARNING: This is an experimental interface that is subject to change.
  TfLiteStatus ApplyOptions(InterpreterOptions* options);

#ifndef DOXYGEN_SKIP
  /// Return the number of subgraphs in the model.
  /// WARNING: This is an experimental API and subject to change.
  size_t subgraphs_size() const { return subgraphs_.size(); }

  /// Get a pointer to a subgraph if in bounds.
  /// WARNING: This is an experimental API and subject to change.
  const Subgraph* subgraph(int subgraph_index) const {
    if (subgraph_index < 0 ||
        static_cast<size_t>(subgraph_index) >= subgraphs_size()) {
      return nullptr;
    }
    return subgraphs_[subgraph_index].get();
  }

  /// WARNING: This is an experimental API and subject to change.
  Subgraph* subgraph(int subgraph_index) {
    return const_cast<Subgraph*>(
        static_cast<const Interpreter*>(this)->subgraph(subgraph_index));
  }

  /// WARNING: Experimental interface, subject to change.
  Subgraph& primary_subgraph() {
    return *subgraphs_.front();  // Safe as subgraphs_ always has 1 entry.
  }

  /// WARNING: Experimental interface, subject to change.
  const Subgraph& primary_subgraph() const {
    return *subgraphs_.front();  // Safe as subgraphs_ always has 1 entry.
  }

  /// WARNING: Experimental interface, subject to change.
  /// Get the error reporter associated with this interpreter.
  ErrorReporter* error_reporter() const { return error_reporter_; }

#endif  // DOXYGEN_SKIP

 private:
  friend class InterpreterBuilder;
  friend class tflite::InterpreterTest;
  friend class tflite::delegates::InterpreterUtils;
  friend class tflite::delegates::test_utils::TestDelegation;
  friend class tflite::interpreter_wrapper::InterpreterWrapper;

  /// Set the value of an external context.
  static void SetExternalContext(struct TfLiteContext* context,
                                 TfLiteExternalContextType type,
                                 TfLiteExternalContext* ctx);

  // Helper method that returns the tensor index corresponding to
  // 'signature_tensor_name' in the SignatureDef identified by 'signature_key'.
  // If 'is_input' is true, the tensor is looked up among the input tensors;
  // otherwise it is looked up among the output tensors.
  // Returns -1 if the tensor is not found.
  int GetTensorIndexFromSignature(const char* signature_tensor_name,
                                  const char* signature_key,
                                  bool is_input) const {
    // Iterate directly and don't use other methods to avoid extra allocation.
    for (const auto& signature : signature_defs_) {
      if (signature.signature_key != signature_key) continue;
      auto& signature_list = (is_input ? signature.inputs : signature.outputs);
      auto tensor_iter = signature_list.find(signature_tensor_name);
      if (tensor_iter == signature_list.end()) return -1;
      return tensor_iter->second;
    }
    return -1;
  }

  // Applies TFLite default delegates.
  TfLiteStatus ApplyLazyDelegateProviders();

  // Private non-experimental implementation of ModifyGraphWithDelegate.
  // Unlike ModifyGraphWithDelegate, ModifyGraphWithDelegateImpl is defined in
  // interpreter.cc rather than in interpreter_experimental.cc, so it can be
  // used to implement other non-experimental methods.
  TfLiteStatus ModifyGraphWithDelegateImpl(TfLiteDelegate* delegate);

  // Same as ModifyGraphWithDelegateImpl except that it takes ownership of the
  // delegate.
  template <typename Delegate, typename Deleter>
  inline TfLiteStatus ModifyGraphWithDelegateImpl(
      std::unique_ptr<Delegate, Deleter>&& delegate) {
    Deleter deleter = std::move(delegate.get_deleter());

    // Note that we retain ownership of the delegate even if graph modification
    // fails, as delegate use will be in an indeterminate state at that point.
    owned_delegates_.emplace_back(
        delegate.release(), [deleter](TfLiteDelegate* delegate_to_delete) {
          deleter(
              static_cast<typename std::unique_ptr<Delegate,
                                                   Deleter>::pointer>(
                  delegate_to_delete));
        });
    return ModifyGraphWithDelegateImpl(owned_delegates_.back().get());
  }

  // Overrides the execution plan. This bounds-checks the indices sent in.
  // Note: Only used during initialization.
  TfLiteStatus SetExecutionPlan(const std::vector<int>& new_plan);

  // Sets the profiler to all subgraphs.
  void SetSubgraphProfiler();

  // Remove delegates (for fallback behaviour). The interpreter is invokable
  // afterwards.
  TfLiteStatus RemoveAllDelegates();

  // Returns true if delegates have been applied.
  bool HasDelegates();

  // Returns true if the model has been fully delegated.
  bool IsFullyDelegated() const;

  // Returns true if the cancellation function returns true.
  bool IsCancelled();

  // Sets the list of signature defs in the model.
  void SetSignatureDef(std::vector<internal::SignatureDef> signature_defs) {
    signature_defs_ = std::move(signature_defs);
  }

  // Sets model metadata as a mapping of name (key) and buffer (value) strings.
  // Used by InterpreterBuilder; should be called after setting up subgraphs.
  TfLiteStatus SetMetadata(const std::map<std::string, std::string>& metadata);

  /// Adds `subgraphs_to_add` subgraphs, preserving pre-existing Subgraph
  /// entries. The value pointed to by `first_new_subgraph_index` will be set
  /// to the index of the first new subgraph if `first_new_subgraph_index` is
  /// non-null.
  void AddSubgraphs(int subgraphs_to_add,
                    int* first_new_subgraph_index = nullptr);

  /// Implementation of SetProfiler.
  /// Unlike SetProfiler, this is defined in interpreter.cc rather than in
  /// interpreter_experimental.cc, so it can be used by interpreter_builder.cc.
  void SetProfilerImpl(std::unique_ptr<Profiler> profiler);

  TfLiteStatus ApplyOptionsImpl(InterpreterOptions* options);

  // A pure C data structure used to communicate with the pure C plugin
  // interface. To avoid copying tensor metadata, this is also the definitive
  // structure to store tensors.
  // This is the primary subgraph's context.
  TfLiteContext* context_ = nullptr;

  // The error reporter delegate that tflite will forward errors to.
  ErrorReporter* error_reporter_ = nullptr;

  // List of delegates that have been installed and are owned by this
  // interpreter instance. Useful if client delegate ownership is burdensome.
  // WARNING: This is an experimental API and subject to change.
  // TODO(b/116667551): Use TfLiteExternalContext for storing state.
  std::vector<
      std::unique_ptr<TfLiteDelegate, std::function<void(TfLiteDelegate*)>>>
      owned_delegates_;

  // A root profiler that holds a list of attached profiler implementations.
  // Will be nullptr if there's no child profiler registered.
  std::unique_ptr<profiling::RootProfiler> root_profiler_;

  bool allow_buffer_handle_output_ = false;

  // List of active external contexts.
  TfLiteExternalContext* external_contexts_[kTfLiteMaxExternalContexts];

  // The default external cpu backend context. After a TFLite interpreter is
  // initialized, 'external_contexts_[kTfLiteCpuBackendContext]' is set to
  // point to this object. However, if this element value is overwritten via
  // calling 'SetExternalContext(kTfLiteCpuBackendContext, ...)', we reset this
  // to nullptr if necessary.
  std::unique_ptr<ExternalCpuBackendContext> own_external_cpu_backend_context_;

  // Subgraphs.
  std::vector<std::unique_ptr<Subgraph>> subgraphs_;

  // A map of resources. Owned by the interpreter and shared by multiple
  // subgraphs.
  resource::ResourceMap resources_;

  // A map of resource IDs. Owned by the interpreter and shared by multiple
  // subgraphs.
  resource::ResourceIDMap resource_ids_;

  // A map of initialization statuses, which indicate whether the
  // initialization subgraph invocation is done or not. Owned by the
  // interpreter and shared by multiple subgraphs.
  resource::InitializationStatusMap initialization_status_map_;

  // Delegates that the TFLite interpreter will apply by default. An empty list
  // means either that there are no delegates to be applied by default, or that
  // the delegates have already been applied and don't need to be applied
  // again.
  using TfLiteDelegateCreator =
      std::function<TfLiteDelegatePtr(int /*num_threads*/)>;
  using TfLiteDelegateCreators = std::vector<TfLiteDelegateCreator>;
  TfLiteDelegateCreators lazy_delegate_providers_;

  // List of SignatureDefs obtained from the model.
  std::vector<internal::SignatureDef> signature_defs_;

  // Map of signature key to its corresponding SignatureRunner object.
  // A SignatureRunner is essentially a wrapper of the Subgraph corresponding
  // to its SignatureDef.
  std::map<std::string, SignatureRunner> signature_runner_map_;

  // Model metadata stored as a mapping of name (key) to buffer (value)
  // strings. Data is mapped from the Metadata in the TFLite flatbuffer model.
  std::map<std::string, std::string> metadata_;

  // The InterpreterOptions object which is being used.
  std::unique_ptr<InterpreterOptions> options_;
};

}  // namespace tflite
#endif  // TENSORFLOW_LITE_INTERPRETER_H_