/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker */ 16*14675a02SAndroid Build Coastguard Worker #ifndef FCP_CLIENT_ENGINE_TFLITE_WRAPPER_H_ 17*14675a02SAndroid Build Coastguard Worker #define FCP_CLIENT_ENGINE_TFLITE_WRAPPER_H_ 18*14675a02SAndroid Build Coastguard Worker 19*14675a02SAndroid Build Coastguard Worker #include <functional> 20*14675a02SAndroid Build Coastguard Worker #include <string> 21*14675a02SAndroid Build Coastguard Worker #include <utility> 22*14675a02SAndroid Build Coastguard Worker 23*14675a02SAndroid Build Coastguard Worker #include "absl/status/status.h" 24*14675a02SAndroid Build Coastguard Worker #include "absl/status/statusor.h" 25*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/caching_error_reporter.h" 26*14675a02SAndroid Build Coastguard Worker #include "fcp/client/interruptible_runner.h" 27*14675a02SAndroid Build Coastguard Worker #include "fcp/client/log_manager.h" 28*14675a02SAndroid Build Coastguard Worker #include "fcp/client/simple_task_environment.h" 29*14675a02SAndroid Build Coastguard Worker #include "tensorflow/lite/delegates/flex/delegate.h" 30*14675a02SAndroid Build Coastguard Worker #include "tensorflow/lite/interpreter.h" 31*14675a02SAndroid Build Coastguard Worker #include "tensorflow/lite/model_builder.h" 32*14675a02SAndroid Build Coastguard Worker 33*14675a02SAndroid Build Coastguard Worker namespace fcp { 34*14675a02SAndroid Build Coastguard Worker namespace client { 35*14675a02SAndroid Build Coastguard Worker namespace engine { 36*14675a02SAndroid Build Coastguard Worker 37*14675a02SAndroid Build Coastguard Worker struct OutputTensors { 38*14675a02SAndroid Build Coastguard Worker std::vector<std::string> output_tensor_names; 39*14675a02SAndroid Build Coastguard Worker std::vector<tensorflow::Tensor> output_tensors; 40*14675a02SAndroid Build Coastguard Worker }; 41*14675a02SAndroid Build Coastguard Worker 42*14675a02SAndroid Build Coastguard Worker // Options for TFLite interpreter. 
// Defaults: no forced release of dynamic tensors, no threshold-based dynamic
// allocation, and delegate clustering left enabled (not disabled).
struct TfLiteInterpreterOptions {
  // When true, TFLite uses dynamic tensor allocation and releases tensors that
  // are no longer needed.
  bool ensure_dynamic_tensors_are_released = false;
  // Threshold at which a tensor counts as "large" and is allocated
  // dynamically (units presumably bytes -- confirm against TFLite's
  // InterpreterOptions). When the threshold is zero, dynamic allocation is not
  // enabled for any tensor.
  int32_t large_tensor_threshold_for_dynamic_allocation = 0;
  // Whether to disable the graph-reordering optimization that clusters
  // delegate ops together.
  bool disable_delegate_clustering = false;
};

// A class to call into TFLite.
// All functions in this interface indicate errors as follows:
// - CANCELLED: interrupted execution.
// - INVALID_ARGUMENT:
//   1. Invalid model.
//   2. Initialization failure of TFLite classes such as Interpreter,
//      Delegate, etc.
//   3. Missing required inputs.
//   4. TensorFlow error. The TensorFlow error messages are included in the
//      Status message.
65*14675a02SAndroid Build Coastguard Worker // This class supports aborting ongoing calls, by polling the provided 66*14675a02SAndroid Build Coastguard Worker // should_abort function. 67*14675a02SAndroid Build Coastguard Worker // Parameters: 68*14675a02SAndroid Build Coastguard Worker // 1. model: The serialized TFLite model. 69*14675a02SAndroid Build Coastguard Worker // 2. should_abort: A function which will be polled periodically to determine 70*14675a02SAndroid Build Coastguard Worker // if the computation should be aborted. 71*14675a02SAndroid Build Coastguard Worker // 3. timing_config: The TimingConfig for an InterruptibleRunner. 72*14675a02SAndroid Build Coastguard Worker // 4. log_manager: A LogManager. 73*14675a02SAndroid Build Coastguard Worker // 5. inputs: A hashmap which has input tensor name as key, tensor data as 74*14675a02SAndroid Build Coastguard Worker // value. 75*14675a02SAndroid Build Coastguard Worker // 6. output_names: The names of the output tensors. The order for these 76*14675a02SAndroid Build Coastguard Worker // tensor names must be deterministic. 
77*14675a02SAndroid Build Coastguard Worker class TfLiteWrapper { 78*14675a02SAndroid Build Coastguard Worker public: 79*14675a02SAndroid Build Coastguard Worker static absl::StatusOr<std::unique_ptr<TfLiteWrapper>> Create( 80*14675a02SAndroid Build Coastguard Worker const std::string& model, std::function<bool()> should_abort, 81*14675a02SAndroid Build Coastguard Worker const InterruptibleRunner::TimingConfig& timing_config, 82*14675a02SAndroid Build Coastguard Worker LogManager* log_manager, 83*14675a02SAndroid Build Coastguard Worker std::unique_ptr<absl::flat_hash_map<std::string, std::string>> inputs, 84*14675a02SAndroid Build Coastguard Worker std::vector<std::string> output_names, 85*14675a02SAndroid Build Coastguard Worker const TfLiteInterpreterOptions& interpreter_options, int32_t num_threads); 86*14675a02SAndroid Build Coastguard Worker 87*14675a02SAndroid Build Coastguard Worker // Wrapper around TfLite's Interpreter::Invoke method. 88*14675a02SAndroid Build Coastguard Worker // If the run succeeds, a vector of output tensors (empty if there's no 89*14675a02SAndroid Build Coastguard Worker // output tensors), or CANCELLED if the training run was cancelled or 90*14675a02SAndroid Build Coastguard Worker // INVALID_ARGUMENT for the rest of errors. 
91*14675a02SAndroid Build Coastguard Worker absl::StatusOr<OutputTensors> Run(); 92*14675a02SAndroid Build Coastguard Worker 93*14675a02SAndroid Build Coastguard Worker private: TfLiteWrapper(std::unique_ptr<tflite::FlatBufferModel> model,std::unique_ptr<CachingErrorReporter> error_reporter,tflite::TfLiteDelegateUniquePtr delegate,std::unique_ptr<tflite::Interpreter> interpreter,std::unique_ptr<InterruptibleRunner> interruptible_runner,std::vector<std::string> output_names)94*14675a02SAndroid Build Coastguard Worker TfLiteWrapper(std::unique_ptr<tflite::FlatBufferModel> model, 95*14675a02SAndroid Build Coastguard Worker std::unique_ptr<CachingErrorReporter> error_reporter, 96*14675a02SAndroid Build Coastguard Worker tflite::TfLiteDelegateUniquePtr delegate, 97*14675a02SAndroid Build Coastguard Worker std::unique_ptr<tflite::Interpreter> interpreter, 98*14675a02SAndroid Build Coastguard Worker std::unique_ptr<InterruptibleRunner> interruptible_runner, 99*14675a02SAndroid Build Coastguard Worker std::vector<std::string> output_names) 100*14675a02SAndroid Build Coastguard Worker : model_(std::move(model)), 101*14675a02SAndroid Build Coastguard Worker error_reporter_(std::move(error_reporter)), 102*14675a02SAndroid Build Coastguard Worker delegate_(std::move(delegate)), 103*14675a02SAndroid Build Coastguard Worker interpreter_(std::move(interpreter)), 104*14675a02SAndroid Build Coastguard Worker interruptible_runner_(std::move(interruptible_runner)), 105*14675a02SAndroid Build Coastguard Worker output_names_(std::move(output_names)) {} 106*14675a02SAndroid Build Coastguard Worker absl::Status ConvertTfLiteStatus(TfLiteStatus status); 107*14675a02SAndroid Build Coastguard Worker absl::StatusOr<OutputTensors> ConstructOutputs(); 108*14675a02SAndroid Build Coastguard Worker 109*14675a02SAndroid Build Coastguard Worker std::unique_ptr<tflite::FlatBufferModel> model_; 110*14675a02SAndroid Build Coastguard Worker std::unique_ptr<CachingErrorReporter> error_reporter_; 
111*14675a02SAndroid Build Coastguard Worker tflite::TfLiteDelegateUniquePtr delegate_; 112*14675a02SAndroid Build Coastguard Worker std::unique_ptr<tflite::Interpreter> interpreter_; 113*14675a02SAndroid Build Coastguard Worker std::unique_ptr<InterruptibleRunner> interruptible_runner_; 114*14675a02SAndroid Build Coastguard Worker const std::vector<std::string> output_names_; 115*14675a02SAndroid Build Coastguard Worker }; 116*14675a02SAndroid Build Coastguard Worker 117*14675a02SAndroid Build Coastguard Worker } // namespace engine 118*14675a02SAndroid Build Coastguard Worker } // namespace client 119*14675a02SAndroid Build Coastguard Worker } // namespace fcp 120*14675a02SAndroid Build Coastguard Worker 121*14675a02SAndroid Build Coastguard Worker #endif // FCP_CLIENT_ENGINE_TFLITE_WRAPPER_H_ 122