xref: /aosp_15_r20/external/federated-compute/fcp/client/engine/tflite_wrapper.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2021 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "fcp/client/engine/tflite_wrapper.h"
17 
18 #include <functional>
19 #include <memory>
20 #include <string>
21 #include <utility>
22 
23 #include "google/protobuf/any.pb.h"
24 #include "absl/status/status.h"
25 #include "absl/status/statusor.h"
26 #include "absl/strings/str_format.h"
27 #include "fcp/base/monitoring.h"
28 #include "tensorflow/core/public/version.h"
29 #include "tensorflow/lite/delegates/flex/util.h"
30 #include "tensorflow/lite/interpreter.h"
31 #include "tensorflow/lite/interpreter_builder.h"
32 #include "tensorflow/lite/kernels/register.h"
33 #include "tensorflow/lite/model_builder.h"
34 #include "tensorflow/lite/string_util.h"
35 
36 namespace fcp {
37 namespace client {
38 namespace engine {
39 
40 using ::tflite::ops::builtin::BuiltinOpResolver;
41 
42 namespace {
43 
AssignStringInput(int index,const std::string & value,tflite::Interpreter * interpreter)44 absl::Status AssignStringInput(int index, const std::string& value,
45                                tflite::Interpreter* interpreter) {
46   TfLiteTensor* tensor = interpreter->tensor(index);
47   if (tensor->type != kTfLiteString) {
48     return absl::InvalidArgumentError("Input tensor is not a string tensor.");
49   }
50 
51   tflite::DynamicBuffer buf;
52   buf.AddString(value.data(), value.length());
53   buf.WriteToTensor(tensor, nullptr);
54   return absl::OkStatus();
55 }
56 
57 }  // anonymous namespace
58 
// Factory for TfLiteWrapper. Parses `model` as a TFLite FlatBuffer, attaches
// the Flex delegate (required to run TensorFlow ops inside TFLite), writes
// the provided string `inputs` into the interpreter's input tensors, and
// wraps execution in an InterruptibleRunner so Run() can be aborted.
//
// Returns InvalidArgumentError if the model cannot be parsed, the interpreter
// cannot be built/configured, or `inputs` is missing a tensor the model
// expects. Note: `interpreter_options` is accepted but not consulted in this
// function body — presumably reserved for future use; verify against callers.
absl::StatusOr<std::unique_ptr<TfLiteWrapper>> TfLiteWrapper::Create(
    const std::string& model, std::function<bool()> should_abort,
    const InterruptibleRunner::TimingConfig& timing_config,
    LogManager* log_manager,
    std::unique_ptr<absl::flat_hash_map<std::string, std::string>> inputs,
    std::vector<std::string> output_names,
    const TfLiteInterpreterOptions& interpreter_options, int32_t num_threads) {
  std::unique_ptr<tflite::FlatBufferModel> flat_buffer_model =
      tflite::FlatBufferModel::BuildFromBuffer(model.c_str(), model.size());
  if (flat_buffer_model == nullptr) {
    return absl::InvalidArgumentError("Failed to build FlatBufferModel.");
  }
  // The training delegate needs to be created before the interpreter.
  auto delegate = tflite::FlexDelegate::Create();
  // CachingErrorReporter captures TFLite error messages so they can be
  // surfaced in the absl::Status errors below (and later in Run()).
  auto error_reporter = std::make_unique<CachingErrorReporter>();
  auto interpreter = std::make_unique<tflite::Interpreter>();

  if (tflite::InterpreterBuilder(
          flat_buffer_model->GetModel(), BuiltinOpResolver(),
          error_reporter.get())(&interpreter) != kTfLiteOk) {
    return absl::InvalidArgumentError(
        absl::StrCat("Failed to initiate interpreter: ",
                     error_reporter->GetFirstErrorMessage()));
  }
  // Thread count is set before the delegate is applied; NOTE(review): the
  // ordering looks deliberate — confirm against TFLite's SetNumThreads docs
  // before reordering.
  interpreter->SetNumThreads(num_threads);
  if (interpreter->ModifyGraphWithDelegate(delegate.get()) != kTfLiteOk) {
    return absl::InvalidArgumentError(
        absl::StrCat("Failed to modify graph with TrainingFlexDelegate: ",
                     error_reporter->GetFirstErrorMessage()));
  }
  if (interpreter->AllocateTensors() != kTfLiteOk) {
    return absl::InvalidArgumentError(
        absl::StrCat("Failed to allocate tensors: ",
                     error_reporter->GetFirstErrorMessage()));
  }
  // Lets Run() cancel an in-flight Invoke(): the Flex delegate polls
  // HasCancelled via this cancellation function.
  interpreter->SetCancellationFunction(delegate->data_,
                                       tflite::FlexDelegate::HasCancelled);
  // Every model input must be supplied by `inputs`, keyed by tensor name;
  // extra entries in `inputs` are silently ignored.
  for (const auto& input : interpreter->inputs()) {
    std::string key = interpreter->GetInputName(input);
    if (inputs->find(key) == inputs->end()) {
      return absl::InvalidArgumentError("Unexpected input tensor.");
    }
    FCP_RETURN_IF_ERROR(
        AssignStringInput(input, inputs->at(key), interpreter.get()));
  }
  // Create an InterruptibleRunner to execute TF calls in a background thread,
  // allowing us to abort them if need be.
  auto runner = std::make_unique<InterruptibleRunner>(
      log_manager, should_abort, timing_config,
      InterruptibleRunner::DiagnosticsConfig{
          .interrupted =
              ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION,
          .interrupt_timeout = ProdDiagCode::
              BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION_TIMED_OUT,
          .interrupted_extended = ProdDiagCode::
              BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_COMPLETED,
          .interrupt_timeout_extended = ProdDiagCode::
              BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_TIMED_OUT});
  // WrapUnique because TfLiteWrapper's constructor is not publicly
  // accessible to make_unique.
  return absl::WrapUnique(
      new TfLiteWrapper(std::move(flat_buffer_model), std::move(error_reporter),
                        std::move(delegate), std::move(interpreter),
                        std::move(runner), std::move(output_names)));
}
122 
Run()123 absl::StatusOr<OutputTensors> TfLiteWrapper::Run() {
124   auto* interpreter_raw_pointer = interpreter_.get();
125   auto tflite_runnable = [interpreter_raw_pointer, this]() {
126     return ConvertTfLiteStatus(interpreter_raw_pointer->Invoke());
127   };
128   auto* delegate_raw_pointer =
129       static_cast<tflite::FlexDelegate*>(delegate_->data_);
130   auto abort_tflite = [delegate_raw_pointer]() {
131     delegate_raw_pointer->Cancel();
132   };
133   FCP_RETURN_IF_ERROR(
134       interruptible_runner_->Run(tflite_runnable, abort_tflite));
135   // handles output tensors
136   return ConstructOutputs();
137 }
138 
ConvertTfLiteStatus(TfLiteStatus status)139 absl::Status TfLiteWrapper::ConvertTfLiteStatus(TfLiteStatus status) {
140   switch (status) {
141     case kTfLiteOk:
142       return absl::OkStatus();
143     case kTfLiteError: {
144       // TfLite doesn't differentiate the error type when the training is
145       // cancelled or an error happened during training. It also doesn't
146       // distinguish different error types thrown by Tensorflow. Therefore, we
147       // need to check whether the training was cancelled, and record the error
148       // message from the ErrorReporter.
149       if (tflite::FlexDelegate::HasCancelled(delegate_->data_)) {
150         return absl::CancelledError("Training is cancelled.");
151       }
152       std::string error = error_reporter_->GetFirstErrorMessage();
153       if (error.empty()) {
154         return absl::InvalidArgumentError("Empty error messages returned.");
155       }
156       // Use the first error we encountered.
157       return absl::InvalidArgumentError(error);
158     }
159     case kTfLiteDelegateError:
160       return absl::InvalidArgumentError("TfLite delegate error.");
161     case kTfLiteApplicationError:
162       return absl::InvalidArgumentError(
163           "An error in applying a delegate due to incompatibility between "
164           "runtime and delegate");
165     case kTfLiteDelegateDataNotFound:
166       return absl::InvalidArgumentError(
167           "Serialized delegate data not being found");
168     case kTfLiteDelegateDataWriteError:
169       return absl::InvalidArgumentError(
170           "Data-writing issues in delegate serialization");
171     case kTfLiteDelegateDataReadError:
172       return absl::InvalidArgumentError(
173           "Data-reading issues in delegate serialization.");
174     case kTfLiteUnresolvedOps:
175       return absl::InvalidArgumentError(
176           "The TF Lite model has ops that cannot be resolved at runtime.");
177     default:
178       return absl::InternalError("Unexpected TfLiteStatus.");
179   }
180 }
181 
ConstructOutputs()182 absl::StatusOr<OutputTensors> TfLiteWrapper::ConstructOutputs() {
183   if (interpreter_->outputs().size() != output_names_.size()) {
184     return absl::InvalidArgumentError(
185         absl::StrFormat("The number of output tensors is wrong. Expected: %d, "
186                         "Returned by TFLite interpreter: %d",
187                         output_names_.size(), interpreter_->outputs().size()));
188   }
189   OutputTensors output_tensors;
190   // The order of the output tensors should match the order of output tensor
191   // names.
192   for (int output_tensor_index : interpreter_->outputs()) {
193     auto tensor = tflite::flex::CreateTfTensorFromTfLiteTensor(
194         interpreter_->tensor(output_tensor_index));
195     if (!tensor.ok()) {
196 #if TF_GRAPH_DEF_VERSION < 1467
197       return absl::InvalidArgumentError(tensor.status().error_message());
198 #else
199       return absl::InvalidArgumentError(tensor.status().message());
200 #endif
201     }
202     output_tensors.output_tensors.push_back(*tensor);
203   }
204   output_tensors.output_tensor_names = output_names_;
205   return output_tensors;
206 }
207 
208 }  // namespace engine
209 }  // namespace client
210 }  // namespace fcp
211