xref: /aosp_15_r20/external/federated-compute/fcp/client/engine/tflite_wrapper.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker  * Copyright 2021 Google LLC
3*14675a02SAndroid Build Coastguard Worker  *
4*14675a02SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker  *
8*14675a02SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker  *
10*14675a02SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker  * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker  */
16*14675a02SAndroid Build Coastguard Worker #include "fcp/client/engine/tflite_wrapper.h"
17*14675a02SAndroid Build Coastguard Worker 
18*14675a02SAndroid Build Coastguard Worker #include <functional>
19*14675a02SAndroid Build Coastguard Worker #include <memory>
20*14675a02SAndroid Build Coastguard Worker #include <string>
21*14675a02SAndroid Build Coastguard Worker #include <utility>
22*14675a02SAndroid Build Coastguard Worker 
23*14675a02SAndroid Build Coastguard Worker #include "google/protobuf/any.pb.h"
24*14675a02SAndroid Build Coastguard Worker #include "absl/status/status.h"
25*14675a02SAndroid Build Coastguard Worker #include "absl/status/statusor.h"
26*14675a02SAndroid Build Coastguard Worker #include "absl/strings/str_format.h"
27*14675a02SAndroid Build Coastguard Worker #include "fcp/base/monitoring.h"
28*14675a02SAndroid Build Coastguard Worker #include "tensorflow/core/public/version.h"
29*14675a02SAndroid Build Coastguard Worker #include "tensorflow/lite/delegates/flex/util.h"
30*14675a02SAndroid Build Coastguard Worker #include "tensorflow/lite/interpreter.h"
31*14675a02SAndroid Build Coastguard Worker #include "tensorflow/lite/interpreter_builder.h"
32*14675a02SAndroid Build Coastguard Worker #include "tensorflow/lite/kernels/register.h"
33*14675a02SAndroid Build Coastguard Worker #include "tensorflow/lite/model_builder.h"
34*14675a02SAndroid Build Coastguard Worker #include "tensorflow/lite/string_util.h"
35*14675a02SAndroid Build Coastguard Worker 
36*14675a02SAndroid Build Coastguard Worker namespace fcp {
37*14675a02SAndroid Build Coastguard Worker namespace client {
38*14675a02SAndroid Build Coastguard Worker namespace engine {
39*14675a02SAndroid Build Coastguard Worker 
40*14675a02SAndroid Build Coastguard Worker using ::tflite::ops::builtin::BuiltinOpResolver;
41*14675a02SAndroid Build Coastguard Worker 
42*14675a02SAndroid Build Coastguard Worker namespace {
43*14675a02SAndroid Build Coastguard Worker 
AssignStringInput(int index,const std::string & value,tflite::Interpreter * interpreter)44*14675a02SAndroid Build Coastguard Worker absl::Status AssignStringInput(int index, const std::string& value,
45*14675a02SAndroid Build Coastguard Worker                                tflite::Interpreter* interpreter) {
46*14675a02SAndroid Build Coastguard Worker   TfLiteTensor* tensor = interpreter->tensor(index);
47*14675a02SAndroid Build Coastguard Worker   if (tensor->type != kTfLiteString) {
48*14675a02SAndroid Build Coastguard Worker     return absl::InvalidArgumentError("Input tensor is not a string tensor.");
49*14675a02SAndroid Build Coastguard Worker   }
50*14675a02SAndroid Build Coastguard Worker 
51*14675a02SAndroid Build Coastguard Worker   tflite::DynamicBuffer buf;
52*14675a02SAndroid Build Coastguard Worker   buf.AddString(value.data(), value.length());
53*14675a02SAndroid Build Coastguard Worker   buf.WriteToTensor(tensor, nullptr);
54*14675a02SAndroid Build Coastguard Worker   return absl::OkStatus();
55*14675a02SAndroid Build Coastguard Worker }
56*14675a02SAndroid Build Coastguard Worker 
57*14675a02SAndroid Build Coastguard Worker }  // anonymous namespace
58*14675a02SAndroid Build Coastguard Worker 
Create(const std::string & model,std::function<bool ()> should_abort,const InterruptibleRunner::TimingConfig & timing_config,LogManager * log_manager,std::unique_ptr<absl::flat_hash_map<std::string,std::string>> inputs,std::vector<std::string> output_names,const TfLiteInterpreterOptions & interpreter_options,int32_t num_threads)59*14675a02SAndroid Build Coastguard Worker absl::StatusOr<std::unique_ptr<TfLiteWrapper>> TfLiteWrapper::Create(
60*14675a02SAndroid Build Coastguard Worker     const std::string& model, std::function<bool()> should_abort,
61*14675a02SAndroid Build Coastguard Worker     const InterruptibleRunner::TimingConfig& timing_config,
62*14675a02SAndroid Build Coastguard Worker     LogManager* log_manager,
63*14675a02SAndroid Build Coastguard Worker     std::unique_ptr<absl::flat_hash_map<std::string, std::string>> inputs,
64*14675a02SAndroid Build Coastguard Worker     std::vector<std::string> output_names,
65*14675a02SAndroid Build Coastguard Worker     const TfLiteInterpreterOptions& interpreter_options, int32_t num_threads) {
66*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<tflite::FlatBufferModel> flat_buffer_model =
67*14675a02SAndroid Build Coastguard Worker       tflite::FlatBufferModel::BuildFromBuffer(model.c_str(), model.size());
68*14675a02SAndroid Build Coastguard Worker   if (flat_buffer_model == nullptr) {
69*14675a02SAndroid Build Coastguard Worker     return absl::InvalidArgumentError("Failed to build FlatBufferModel.");
70*14675a02SAndroid Build Coastguard Worker   }
71*14675a02SAndroid Build Coastguard Worker   // The training delegate needs to be created before the interpreter.
72*14675a02SAndroid Build Coastguard Worker   auto delegate = tflite::FlexDelegate::Create();
73*14675a02SAndroid Build Coastguard Worker   auto error_reporter = std::make_unique<CachingErrorReporter>();
74*14675a02SAndroid Build Coastguard Worker   auto interpreter = std::make_unique<tflite::Interpreter>();
75*14675a02SAndroid Build Coastguard Worker 
76*14675a02SAndroid Build Coastguard Worker   if (tflite::InterpreterBuilder(
77*14675a02SAndroid Build Coastguard Worker           flat_buffer_model->GetModel(), BuiltinOpResolver(),
78*14675a02SAndroid Build Coastguard Worker           error_reporter.get())(&interpreter) != kTfLiteOk) {
79*14675a02SAndroid Build Coastguard Worker     return absl::InvalidArgumentError(
80*14675a02SAndroid Build Coastguard Worker         absl::StrCat("Failed to initiate interpreter: ",
81*14675a02SAndroid Build Coastguard Worker                      error_reporter->GetFirstErrorMessage()));
82*14675a02SAndroid Build Coastguard Worker   }
83*14675a02SAndroid Build Coastguard Worker   interpreter->SetNumThreads(num_threads);
84*14675a02SAndroid Build Coastguard Worker   if (interpreter->ModifyGraphWithDelegate(delegate.get()) != kTfLiteOk) {
85*14675a02SAndroid Build Coastguard Worker     return absl::InvalidArgumentError(
86*14675a02SAndroid Build Coastguard Worker         absl::StrCat("Failed to modify graph with TrainingFlexDelegate: ",
87*14675a02SAndroid Build Coastguard Worker                      error_reporter->GetFirstErrorMessage()));
88*14675a02SAndroid Build Coastguard Worker   }
89*14675a02SAndroid Build Coastguard Worker   if (interpreter->AllocateTensors() != kTfLiteOk) {
90*14675a02SAndroid Build Coastguard Worker     return absl::InvalidArgumentError(
91*14675a02SAndroid Build Coastguard Worker         absl::StrCat("Failed to allocate tensors: ",
92*14675a02SAndroid Build Coastguard Worker                      error_reporter->GetFirstErrorMessage()));
93*14675a02SAndroid Build Coastguard Worker   }
94*14675a02SAndroid Build Coastguard Worker   interpreter->SetCancellationFunction(delegate->data_,
95*14675a02SAndroid Build Coastguard Worker                                        tflite::FlexDelegate::HasCancelled);
96*14675a02SAndroid Build Coastguard Worker   for (const auto& input : interpreter->inputs()) {
97*14675a02SAndroid Build Coastguard Worker     std::string key = interpreter->GetInputName(input);
98*14675a02SAndroid Build Coastguard Worker     if (inputs->find(key) == inputs->end()) {
99*14675a02SAndroid Build Coastguard Worker       return absl::InvalidArgumentError("Unexpected input tensor.");
100*14675a02SAndroid Build Coastguard Worker     }
101*14675a02SAndroid Build Coastguard Worker     FCP_RETURN_IF_ERROR(
102*14675a02SAndroid Build Coastguard Worker         AssignStringInput(input, inputs->at(key), interpreter.get()));
103*14675a02SAndroid Build Coastguard Worker   }
104*14675a02SAndroid Build Coastguard Worker   // Create an InterruptibleRunner to execute TF calls in a background thread,
105*14675a02SAndroid Build Coastguard Worker   // allowing us to abort them if need be.
106*14675a02SAndroid Build Coastguard Worker   auto runner = std::make_unique<InterruptibleRunner>(
107*14675a02SAndroid Build Coastguard Worker       log_manager, should_abort, timing_config,
108*14675a02SAndroid Build Coastguard Worker       InterruptibleRunner::DiagnosticsConfig{
109*14675a02SAndroid Build Coastguard Worker           .interrupted =
110*14675a02SAndroid Build Coastguard Worker               ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION,
111*14675a02SAndroid Build Coastguard Worker           .interrupt_timeout = ProdDiagCode::
112*14675a02SAndroid Build Coastguard Worker               BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION_TIMED_OUT,
113*14675a02SAndroid Build Coastguard Worker           .interrupted_extended = ProdDiagCode::
114*14675a02SAndroid Build Coastguard Worker               BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_COMPLETED,
115*14675a02SAndroid Build Coastguard Worker           .interrupt_timeout_extended = ProdDiagCode::
116*14675a02SAndroid Build Coastguard Worker               BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_TIMED_OUT});
117*14675a02SAndroid Build Coastguard Worker   return absl::WrapUnique(
118*14675a02SAndroid Build Coastguard Worker       new TfLiteWrapper(std::move(flat_buffer_model), std::move(error_reporter),
119*14675a02SAndroid Build Coastguard Worker                         std::move(delegate), std::move(interpreter),
120*14675a02SAndroid Build Coastguard Worker                         std::move(runner), std::move(output_names)));
121*14675a02SAndroid Build Coastguard Worker }
122*14675a02SAndroid Build Coastguard Worker 
Run()123*14675a02SAndroid Build Coastguard Worker absl::StatusOr<OutputTensors> TfLiteWrapper::Run() {
124*14675a02SAndroid Build Coastguard Worker   auto* interpreter_raw_pointer = interpreter_.get();
125*14675a02SAndroid Build Coastguard Worker   auto tflite_runnable = [interpreter_raw_pointer, this]() {
126*14675a02SAndroid Build Coastguard Worker     return ConvertTfLiteStatus(interpreter_raw_pointer->Invoke());
127*14675a02SAndroid Build Coastguard Worker   };
128*14675a02SAndroid Build Coastguard Worker   auto* delegate_raw_pointer =
129*14675a02SAndroid Build Coastguard Worker       static_cast<tflite::FlexDelegate*>(delegate_->data_);
130*14675a02SAndroid Build Coastguard Worker   auto abort_tflite = [delegate_raw_pointer]() {
131*14675a02SAndroid Build Coastguard Worker     delegate_raw_pointer->Cancel();
132*14675a02SAndroid Build Coastguard Worker   };
133*14675a02SAndroid Build Coastguard Worker   FCP_RETURN_IF_ERROR(
134*14675a02SAndroid Build Coastguard Worker       interruptible_runner_->Run(tflite_runnable, abort_tflite));
135*14675a02SAndroid Build Coastguard Worker   // handles output tensors
136*14675a02SAndroid Build Coastguard Worker   return ConstructOutputs();
137*14675a02SAndroid Build Coastguard Worker }
138*14675a02SAndroid Build Coastguard Worker 
ConvertTfLiteStatus(TfLiteStatus status)139*14675a02SAndroid Build Coastguard Worker absl::Status TfLiteWrapper::ConvertTfLiteStatus(TfLiteStatus status) {
140*14675a02SAndroid Build Coastguard Worker   switch (status) {
141*14675a02SAndroid Build Coastguard Worker     case kTfLiteOk:
142*14675a02SAndroid Build Coastguard Worker       return absl::OkStatus();
143*14675a02SAndroid Build Coastguard Worker     case kTfLiteError: {
144*14675a02SAndroid Build Coastguard Worker       // TfLite doesn't differentiate the error type when the training is
145*14675a02SAndroid Build Coastguard Worker       // cancelled or an error happened during training. It also doesn't
146*14675a02SAndroid Build Coastguard Worker       // distinguish different error types thrown by Tensorflow. Therefore, we
147*14675a02SAndroid Build Coastguard Worker       // need to check whether the training was cancelled, and record the error
148*14675a02SAndroid Build Coastguard Worker       // message from the ErrorReporter.
149*14675a02SAndroid Build Coastguard Worker       if (tflite::FlexDelegate::HasCancelled(delegate_->data_)) {
150*14675a02SAndroid Build Coastguard Worker         return absl::CancelledError("Training is cancelled.");
151*14675a02SAndroid Build Coastguard Worker       }
152*14675a02SAndroid Build Coastguard Worker       std::string error = error_reporter_->GetFirstErrorMessage();
153*14675a02SAndroid Build Coastguard Worker       if (error.empty()) {
154*14675a02SAndroid Build Coastguard Worker         return absl::InvalidArgumentError("Empty error messages returned.");
155*14675a02SAndroid Build Coastguard Worker       }
156*14675a02SAndroid Build Coastguard Worker       // Use the first error we encountered.
157*14675a02SAndroid Build Coastguard Worker       return absl::InvalidArgumentError(error);
158*14675a02SAndroid Build Coastguard Worker     }
159*14675a02SAndroid Build Coastguard Worker     case kTfLiteDelegateError:
160*14675a02SAndroid Build Coastguard Worker       return absl::InvalidArgumentError("TfLite delegate error.");
161*14675a02SAndroid Build Coastguard Worker     case kTfLiteApplicationError:
162*14675a02SAndroid Build Coastguard Worker       return absl::InvalidArgumentError(
163*14675a02SAndroid Build Coastguard Worker           "An error in applying a delegate due to incompatibility between "
164*14675a02SAndroid Build Coastguard Worker           "runtime and delegate");
165*14675a02SAndroid Build Coastguard Worker     case kTfLiteDelegateDataNotFound:
166*14675a02SAndroid Build Coastguard Worker       return absl::InvalidArgumentError(
167*14675a02SAndroid Build Coastguard Worker           "Serialized delegate data not being found");
168*14675a02SAndroid Build Coastguard Worker     case kTfLiteDelegateDataWriteError:
169*14675a02SAndroid Build Coastguard Worker       return absl::InvalidArgumentError(
170*14675a02SAndroid Build Coastguard Worker           "Data-writing issues in delegate serialization");
171*14675a02SAndroid Build Coastguard Worker     case kTfLiteDelegateDataReadError:
172*14675a02SAndroid Build Coastguard Worker       return absl::InvalidArgumentError(
173*14675a02SAndroid Build Coastguard Worker           "Data-reading issues in delegate serialization.");
174*14675a02SAndroid Build Coastguard Worker     case kTfLiteUnresolvedOps:
175*14675a02SAndroid Build Coastguard Worker       return absl::InvalidArgumentError(
176*14675a02SAndroid Build Coastguard Worker           "The TF Lite model has ops that cannot be resolved at runtime.");
177*14675a02SAndroid Build Coastguard Worker     default:
178*14675a02SAndroid Build Coastguard Worker       return absl::InternalError("Unexpected TfLiteStatus.");
179*14675a02SAndroid Build Coastguard Worker   }
180*14675a02SAndroid Build Coastguard Worker }
181*14675a02SAndroid Build Coastguard Worker 
ConstructOutputs()182*14675a02SAndroid Build Coastguard Worker absl::StatusOr<OutputTensors> TfLiteWrapper::ConstructOutputs() {
183*14675a02SAndroid Build Coastguard Worker   if (interpreter_->outputs().size() != output_names_.size()) {
184*14675a02SAndroid Build Coastguard Worker     return absl::InvalidArgumentError(
185*14675a02SAndroid Build Coastguard Worker         absl::StrFormat("The number of output tensors is wrong. Expected: %d, "
186*14675a02SAndroid Build Coastguard Worker                         "Returned by TFLite interpreter: %d",
187*14675a02SAndroid Build Coastguard Worker                         output_names_.size(), interpreter_->outputs().size()));
188*14675a02SAndroid Build Coastguard Worker   }
189*14675a02SAndroid Build Coastguard Worker   OutputTensors output_tensors;
190*14675a02SAndroid Build Coastguard Worker   // The order of the output tensors should match the order of output tensor
191*14675a02SAndroid Build Coastguard Worker   // names.
192*14675a02SAndroid Build Coastguard Worker   for (int output_tensor_index : interpreter_->outputs()) {
193*14675a02SAndroid Build Coastguard Worker     auto tensor = tflite::flex::CreateTfTensorFromTfLiteTensor(
194*14675a02SAndroid Build Coastguard Worker         interpreter_->tensor(output_tensor_index));
195*14675a02SAndroid Build Coastguard Worker     if (!tensor.ok()) {
196*14675a02SAndroid Build Coastguard Worker #if TF_GRAPH_DEF_VERSION < 1467
197*14675a02SAndroid Build Coastguard Worker       return absl::InvalidArgumentError(tensor.status().error_message());
198*14675a02SAndroid Build Coastguard Worker #else
199*14675a02SAndroid Build Coastguard Worker       return absl::InvalidArgumentError(tensor.status().message());
200*14675a02SAndroid Build Coastguard Worker #endif
201*14675a02SAndroid Build Coastguard Worker     }
202*14675a02SAndroid Build Coastguard Worker     output_tensors.output_tensors.push_back(*tensor);
203*14675a02SAndroid Build Coastguard Worker   }
204*14675a02SAndroid Build Coastguard Worker   output_tensors.output_tensor_names = output_names_;
205*14675a02SAndroid Build Coastguard Worker   return output_tensors;
206*14675a02SAndroid Build Coastguard Worker }
207*14675a02SAndroid Build Coastguard Worker 
208*14675a02SAndroid Build Coastguard Worker }  // namespace engine
209*14675a02SAndroid Build Coastguard Worker }  // namespace client
210*14675a02SAndroid Build Coastguard Worker }  // namespace fcp
211