xref: /aosp_15_r20/external/federated-compute/fcp/client/engine/tflite_wrapper.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker  * Copyright 2021 Google LLC
3*14675a02SAndroid Build Coastguard Worker  *
4*14675a02SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker  *
8*14675a02SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker  *
10*14675a02SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker  * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker  */
16*14675a02SAndroid Build Coastguard Worker #ifndef FCP_CLIENT_ENGINE_TFLITE_WRAPPER_H_
17*14675a02SAndroid Build Coastguard Worker #define FCP_CLIENT_ENGINE_TFLITE_WRAPPER_H_
18*14675a02SAndroid Build Coastguard Worker 
19*14675a02SAndroid Build Coastguard Worker #include <functional>
20*14675a02SAndroid Build Coastguard Worker #include <string>
21*14675a02SAndroid Build Coastguard Worker #include <utility>
22*14675a02SAndroid Build Coastguard Worker 
23*14675a02SAndroid Build Coastguard Worker #include "absl/status/status.h"
24*14675a02SAndroid Build Coastguard Worker #include "absl/status/statusor.h"
25*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/caching_error_reporter.h"
26*14675a02SAndroid Build Coastguard Worker #include "fcp/client/interruptible_runner.h"
27*14675a02SAndroid Build Coastguard Worker #include "fcp/client/log_manager.h"
28*14675a02SAndroid Build Coastguard Worker #include "fcp/client/simple_task_environment.h"
29*14675a02SAndroid Build Coastguard Worker #include "tensorflow/lite/delegates/flex/delegate.h"
30*14675a02SAndroid Build Coastguard Worker #include "tensorflow/lite/interpreter.h"
31*14675a02SAndroid Build Coastguard Worker #include "tensorflow/lite/model_builder.h"
32*14675a02SAndroid Build Coastguard Worker 
33*14675a02SAndroid Build Coastguard Worker namespace fcp {
34*14675a02SAndroid Build Coastguard Worker namespace client {
35*14675a02SAndroid Build Coastguard Worker namespace engine {
36*14675a02SAndroid Build Coastguard Worker 
37*14675a02SAndroid Build Coastguard Worker struct OutputTensors {
38*14675a02SAndroid Build Coastguard Worker   std::vector<std::string> output_tensor_names;
39*14675a02SAndroid Build Coastguard Worker   std::vector<tensorflow::Tensor> output_tensors;
40*14675a02SAndroid Build Coastguard Worker };
41*14675a02SAndroid Build Coastguard Worker 
// Options for the TFLite interpreter. All options default to "off".
struct TfLiteInterpreterOptions {
  // When true, TFLite uses dynamic tensor allocation and releases tensors that
  // are no longer needed.
  bool ensure_dynamic_tensors_are_released = false;
  // When the threshold is zero, dynamic allocation is not enabled for any
  // tensor.
  int32_t large_tensor_threshold_for_dynamic_allocation = 0;
  // Whether to disable the graph-reordering optimization that clusters delegate
  // ops together.
  bool disable_delegate_clustering = false;
};
54*14675a02SAndroid Build Coastguard Worker 
55*14675a02SAndroid Build Coastguard Worker // A class to call into TFLite.
56*14675a02SAndroid Build Coastguard Worker // All functions in this interface indicate errors as follows:
57*14675a02SAndroid Build Coastguard Worker // - CANCELLED: interrupted execution
58*14675a02SAndroid Build Coastguard Worker // - INVALID_ARGUMENT:
59*14675a02SAndroid Build Coastguard Worker //    1. Invalid model.
60*14675a02SAndroid Build Coastguard Worker //    2. Initialization failure for TFLite required classes such as Interpreter,
61*14675a02SAndroid Build Coastguard Worker //    Delegate etc.
62*14675a02SAndroid Build Coastguard Worker //    3. Missing required inputs.
63*14675a02SAndroid Build Coastguard Worker //    4. TensorFlow error. The TensorFlow error messages are included in the
64*14675a02SAndroid Build Coastguard Worker //    Status message.
65*14675a02SAndroid Build Coastguard Worker // This class supports aborting ongoing calls, by polling the provided
66*14675a02SAndroid Build Coastguard Worker // should_abort function.
67*14675a02SAndroid Build Coastguard Worker // Parameters:
68*14675a02SAndroid Build Coastguard Worker //    1. model: The serialized TFLite model.
69*14675a02SAndroid Build Coastguard Worker //    2. should_abort: A function which will be polled periodically to determine
70*14675a02SAndroid Build Coastguard Worker //    if the computation should be aborted.
71*14675a02SAndroid Build Coastguard Worker //    3. timing_config: The TimingConfig for an InterruptibleRunner.
72*14675a02SAndroid Build Coastguard Worker //    4. log_manager: A LogManager.
73*14675a02SAndroid Build Coastguard Worker //    5. inputs: A hashmap which has input tensor name as key, tensor data as
74*14675a02SAndroid Build Coastguard Worker //    value.
75*14675a02SAndroid Build Coastguard Worker //    6. output_names: The names of the output tensors. The order for these
76*14675a02SAndroid Build Coastguard Worker //    tensor names must be deterministic.
77*14675a02SAndroid Build Coastguard Worker class TfLiteWrapper {
78*14675a02SAndroid Build Coastguard Worker  public:
79*14675a02SAndroid Build Coastguard Worker   static absl::StatusOr<std::unique_ptr<TfLiteWrapper>> Create(
80*14675a02SAndroid Build Coastguard Worker       const std::string& model, std::function<bool()> should_abort,
81*14675a02SAndroid Build Coastguard Worker       const InterruptibleRunner::TimingConfig& timing_config,
82*14675a02SAndroid Build Coastguard Worker       LogManager* log_manager,
83*14675a02SAndroid Build Coastguard Worker       std::unique_ptr<absl::flat_hash_map<std::string, std::string>> inputs,
84*14675a02SAndroid Build Coastguard Worker       std::vector<std::string> output_names,
85*14675a02SAndroid Build Coastguard Worker       const TfLiteInterpreterOptions& interpreter_options, int32_t num_threads);
86*14675a02SAndroid Build Coastguard Worker 
87*14675a02SAndroid Build Coastguard Worker   // Wrapper around TfLite's Interpreter::Invoke method.
88*14675a02SAndroid Build Coastguard Worker   // If the run succeeds, a vector of output tensors (empty if there's no
89*14675a02SAndroid Build Coastguard Worker   // output tensors), or CANCELLED if the training run was cancelled or
90*14675a02SAndroid Build Coastguard Worker   // INVALID_ARGUMENT for the rest of errors.
91*14675a02SAndroid Build Coastguard Worker   absl::StatusOr<OutputTensors> Run();
92*14675a02SAndroid Build Coastguard Worker 
93*14675a02SAndroid Build Coastguard Worker  private:
TfLiteWrapper(std::unique_ptr<tflite::FlatBufferModel> model,std::unique_ptr<CachingErrorReporter> error_reporter,tflite::TfLiteDelegateUniquePtr delegate,std::unique_ptr<tflite::Interpreter> interpreter,std::unique_ptr<InterruptibleRunner> interruptible_runner,std::vector<std::string> output_names)94*14675a02SAndroid Build Coastguard Worker   TfLiteWrapper(std::unique_ptr<tflite::FlatBufferModel> model,
95*14675a02SAndroid Build Coastguard Worker                 std::unique_ptr<CachingErrorReporter> error_reporter,
96*14675a02SAndroid Build Coastguard Worker                 tflite::TfLiteDelegateUniquePtr delegate,
97*14675a02SAndroid Build Coastguard Worker                 std::unique_ptr<tflite::Interpreter> interpreter,
98*14675a02SAndroid Build Coastguard Worker                 std::unique_ptr<InterruptibleRunner> interruptible_runner,
99*14675a02SAndroid Build Coastguard Worker                 std::vector<std::string> output_names)
100*14675a02SAndroid Build Coastguard Worker       : model_(std::move(model)),
101*14675a02SAndroid Build Coastguard Worker         error_reporter_(std::move(error_reporter)),
102*14675a02SAndroid Build Coastguard Worker         delegate_(std::move(delegate)),
103*14675a02SAndroid Build Coastguard Worker         interpreter_(std::move(interpreter)),
104*14675a02SAndroid Build Coastguard Worker         interruptible_runner_(std::move(interruptible_runner)),
105*14675a02SAndroid Build Coastguard Worker         output_names_(std::move(output_names)) {}
106*14675a02SAndroid Build Coastguard Worker   absl::Status ConvertTfLiteStatus(TfLiteStatus status);
107*14675a02SAndroid Build Coastguard Worker   absl::StatusOr<OutputTensors> ConstructOutputs();
108*14675a02SAndroid Build Coastguard Worker 
109*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<tflite::FlatBufferModel> model_;
110*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<CachingErrorReporter> error_reporter_;
111*14675a02SAndroid Build Coastguard Worker   tflite::TfLiteDelegateUniquePtr delegate_;
112*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<tflite::Interpreter> interpreter_;
113*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<InterruptibleRunner> interruptible_runner_;
114*14675a02SAndroid Build Coastguard Worker   const std::vector<std::string> output_names_;
115*14675a02SAndroid Build Coastguard Worker };
116*14675a02SAndroid Build Coastguard Worker 
117*14675a02SAndroid Build Coastguard Worker }  // namespace engine
118*14675a02SAndroid Build Coastguard Worker }  // namespace client
119*14675a02SAndroid Build Coastguard Worker }  // namespace fcp
120*14675a02SAndroid Build Coastguard Worker 
121*14675a02SAndroid Build Coastguard Worker #endif  // FCP_CLIENT_ENGINE_TFLITE_WRAPPER_H_
122