/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "fcp/client/engine/tflite_wrapper.h"

#include <functional>
#include <memory>
#include <string>
#include <utility>

#include "google/protobuf/any.pb.h"
#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "fcp/base/monitoring.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/lite/delegates/flex/delegate.h"
#include "tensorflow/lite/delegates/flex/util.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/interpreter_builder.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model_builder.h"
#include "tensorflow/lite/string_util.h"

namespace fcp {
namespace client {
namespace engine {

using ::tflite::ops::builtin::BuiltinOpResolver;

namespace {
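// Writes `value` into the TfLite string tensor at `index`, failing if the
// tensor is not of string type.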
absl::Status AssignStringInput(int index, const std::string& value,
                               tflite::Interpreter* interpreter) {
  TfLiteTensor* tensor = interpreter->tensor(index);
  if (tensor->type != kTfLiteString) {
    return absl::InvalidArgumentError("Input tensor is not a string tensor.");
  }

  tflite::DynamicBuffer buf;
  buf.AddString(value.data(), value.length());
  buf.WriteToTensor(tensor, nullptr);
  return absl::OkStatus();
}

}  // anonymous namespace
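
// Builds a TFLite interpreter for `model`, attaches a Flex delegate (which
// executes TensorFlow ops and supports cancellation), binds the provided
// string inputs, and wraps the result in a TfLiteWrapper that runs the graph
// on an InterruptibleRunner background thread.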
absl::StatusOr<std::unique_ptr<TfLiteWrapper>> TfLiteWrapper::Create(
    const std::string& model, std::function<bool()> should_abort,
    const InterruptibleRunner::TimingConfig& timing_config,
    LogManager* log_manager,
    std::unique_ptr<absl::flat_hash_map<std::string, std::string>> inputs,
    std::vector<std::string> output_names,
    const TfLiteInterpreterOptions& interpreter_options, int32_t num_threads) {
  std::unique_ptr<tflite::FlatBufferModel> flat_buffer_model =
      tflite::FlatBufferModel::BuildFromBuffer(model.c_str(), model.size());
  if (flat_buffer_model == nullptr) {
    return absl::InvalidArgumentError("Failed to build FlatBufferModel.");
  }
  // The training delegate needs to be created before the interpreter.
  auto delegate = tflite::FlexDelegate::Create();
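  // CachingErrorReporter retains the first error message emitted by TFLite so
  // it can be attached to the absl::Status values returned below.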
  auto error_reporter = std::make_unique<CachingErrorReporter>();
  auto interpreter = std::make_unique<tflite::Interpreter>();

  if (tflite::InterpreterBuilder(
          flat_buffer_model->GetModel(), BuiltinOpResolver(),
          error_reporter.get())(&interpreter) != kTfLiteOk) {
    return absl::InvalidArgumentError(
        absl::StrCat("Failed to initialize interpreter: ",
                     error_reporter->GetFirstErrorMessage()));
  }
  interpreter->SetNumThreads(num_threads);
  if (interpreter->ModifyGraphWithDelegate(delegate.get()) != kTfLiteOk) {
    return absl::InvalidArgumentError(
        absl::StrCat("Failed to modify graph with FlexDelegate: ",
                     error_reporter->GetFirstErrorMessage()));
  }
  if (interpreter->AllocateTensors() != kTfLiteOk) {
    return absl::InvalidArgumentError(
        absl::StrCat("Failed to allocate tensors: ",
                     error_reporter->GetFirstErrorMessage()));
  }
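  // Register the Flex delegate's cancellation flag with the interpreter so an
  // in-flight Invoke() can be aborted from another thread.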
  interpreter->SetCancellationFunction(delegate->data_,
                                       tflite::FlexDelegate::HasCancelled);
  for (const auto& input : interpreter->inputs()) {
    std::string key = interpreter->GetInputName(input);
    if (inputs->find(key) == inputs->end()) {
      return absl::InvalidArgumentError("Unexpected input tensor.");
    }
    FCP_RETURN_IF_ERROR(
        AssignStringInput(input, inputs->at(key), interpreter.get()));
  }
  // Create an InterruptibleRunner to execute TF calls in a background thread,
  // allowing us to abort them if need be.
  auto runner = std::make_unique<InterruptibleRunner>(
      log_manager, should_abort, timing_config,
      InterruptibleRunner::DiagnosticsConfig{
          .interrupted =
              ProdDiagCode::BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION,
          .interrupt_timeout = ProdDiagCode::
              BACKGROUND_TRAINING_INTERRUPT_TF_EXECUTION_TIMED_OUT,
          .interrupted_extended = ProdDiagCode::
              BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_COMPLETED,
          .interrupt_timeout_extended = ProdDiagCode::
              BACKGROUND_TRAINING_INTERRUPT_TF_EXTENDED_EXECUTION_TIMED_OUT});
  return absl::WrapUnique(
      new TfLiteWrapper(std::move(flat_buffer_model), std::move(error_reporter),
                        std::move(delegate), std::move(interpreter),
                        std::move(runner), std::move(output_names)));
}
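
// A minimal usage sketch (illustrative only; `serialized_model`,
// `string_inputs`, and the other argument values below are assumptions, not
// prescribed by this API):
//
//   auto wrapper = TfLiteWrapper::Create(
//       serialized_model, should_abort, timing_config, log_manager,
//       std::move(string_inputs), {"output_0"}, interpreter_options,
//       /*num_threads=*/4);
//   if (wrapper.ok()) {
//     absl::StatusOr<OutputTensors> outputs = (*wrapper)->Run();
//   }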

absl::StatusOr<OutputTensors> TfLiteWrapper::Run() {
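  // Run Invoke() on the InterruptibleRunner's background thread; if the
  // runner decides to abort, cancel the in-flight invocation through the
  // Flex delegate.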
  auto* interpreter_raw_pointer = interpreter_.get();
  auto tflite_runnable = [interpreter_raw_pointer, this]() {
    return ConvertTfLiteStatus(interpreter_raw_pointer->Invoke());
  };
  auto* delegate_raw_pointer =
      static_cast<tflite::FlexDelegate*>(delegate_->data_);
  auto abort_tflite = [delegate_raw_pointer]() {
    delegate_raw_pointer->Cancel();
  };
  FCP_RETURN_IF_ERROR(
      interruptible_runner_->Run(tflite_runnable, abort_tflite));
  // Convert the TfLite output tensors into the format expected by callers.
  return ConstructOutputs();
}
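
// Maps a TfLiteStatus returned by the interpreter to an absl::Status,
// distinguishing cooperative cancellation from genuine TFLite/TF errors.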
absl::Status TfLiteWrapper::ConvertTfLiteStatus(TfLiteStatus status) {
  switch (status) {
    case kTfLiteOk:
      return absl::OkStatus();
    case kTfLiteError: {
      // TfLite reports the same error code whether training was cancelled or
      // an error occurred during training, and it does not distinguish
      // between the different error types thrown by TensorFlow. We therefore
      // check for cancellation explicitly and otherwise surface the message
      // recorded by the ErrorReporter.
      if (tflite::FlexDelegate::HasCancelled(delegate_->data_)) {
        return absl::CancelledError("Training was cancelled.");
      }
      std::string error = error_reporter_->GetFirstErrorMessage();
      if (error.empty()) {
        return absl::InvalidArgumentError("Empty error message returned.");
      }
      // Use the first error we encountered.
      return absl::InvalidArgumentError(error);
    }
    case kTfLiteDelegateError:
      return absl::InvalidArgumentError("TfLite delegate error.");
    case kTfLiteApplicationError:
      return absl::InvalidArgumentError(
          "A delegate could not be applied due to an incompatibility between "
          "the runtime and the delegate.");
    case kTfLiteDelegateDataNotFound:
      return absl::InvalidArgumentError("Serialized delegate data not found.");
    case kTfLiteDelegateDataWriteError:
      return absl::InvalidArgumentError(
          "Failed to write delegate serialization data.");
    case kTfLiteDelegateDataReadError:
      return absl::InvalidArgumentError(
          "Failed to read delegate serialization data.");
    case kTfLiteUnresolvedOps:
      return absl::InvalidArgumentError(
          "The TFLite model has ops that cannot be resolved at runtime.");
    default:
      return absl::InternalError("Unexpected TfLiteStatus.");
  }
}
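
// Converts each TfLite output tensor into a tensorflow::Tensor and pairs the
// result with the expected output names, preserving the interpreter's output
// order.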
absl::StatusOr<OutputTensors> TfLiteWrapper::ConstructOutputs() {
  if (interpreter_->outputs().size() != output_names_.size()) {
    return absl::InvalidArgumentError(
        absl::StrFormat("The number of output tensors is wrong. Expected: %d, "
                        "Returned by TFLite interpreter: %d",
                        output_names_.size(), interpreter_->outputs().size()));
  }
  OutputTensors output_tensors;
  // The order of the output tensors should match the order of the output
  // tensor names.
  for (int output_tensor_index : interpreter_->outputs()) {
    auto tensor = tflite::flex::CreateTfTensorFromTfLiteTensor(
        interpreter_->tensor(output_tensor_index));
    if (!tensor.ok()) {
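      // Older TensorFlow releases expose the status message via
      // error_message(); newer ones renamed it to message(), hence the
      // version guard below.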
#if TF_GRAPH_DEF_VERSION < 1467
      return absl::InvalidArgumentError(tensor.status().error_message());
#else
      return absl::InvalidArgumentError(tensor.status().message());
#endif
    }
    output_tensors.output_tensors.push_back(*tensor);
  }
  output_tensors.output_tensor_names = output_names_;
  return output_tensors;
}

}  // namespace engine
}  // namespace client
}  // namespace fcp